[SPARK-16447][ML][SPARKR] LDA wrapper in SparkR

## What changes were proposed in this pull request?

Add an LDA wrapper in SparkR with the following interfaces:

- spark.lda(data, ...)
- spark.posterior(object, newData, ...)
- spark.perplexity(object, ...)
- summary(object)
- write.ml(object)
- read.ml(path)

## How was this patch tested?

Tested with SparkR unit tests.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #14229 from yinxusen/SPARK-16447.

Commit: b72bb62d42 (parent: 68f5087d21)
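Taken together, the new interfaces compose into the following end-to-end workflow. This is a minimal sketch, assuming an active SparkR session and the sample libsvm dataset that the new tests use (`data/mllib/sample_lda_libsvm_data.txt`); the model path is illustrative:

```r
library(SparkR)
sparkR.session()

# Fit an LDA model on a libsvm-format SparkDataFrame (word-count vectors).
text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
model <- spark.lda(text, k = 10, maxIter = 20, optimizer = "em")

# Inspect the fit: effective priors, log likelihood/perplexity, top terms.
stats <- summary(model, maxTermsPerTopic = 10)
print(stats$logPerplexity)

# Posterior topic distributions and perplexity on (here, the training) data.
posterior <- spark.posterior(model, text)
perplexity <- spark.perplexity(model, text)

# Round-trip the model through disk.
modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
write.ml(model, modelPath)
restored <- read.ml(modelPath)
```

The per-file changes below implement this surface: NAMESPACE exports, S4 generics, the R methods, the Scala `LDAWrapper` backend, and its registration in `RWrappers`.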
R/pkg/NAMESPACE

@@ -25,6 +25,9 @@ exportMethods("glm",
               "fitted",
               "spark.naiveBayes",
               "spark.survreg",
+              "spark.lda",
+              "spark.posterior",
+              "spark.perplexity",
               "spark.isoreg",
               "spark.gaussianMixture")
R/pkg/R/generics.R

@@ -1304,6 +1304,19 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
 #' @export
 setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })

+#' @rdname spark.lda
+#' @param ... Additional parameters to tune LDA.
+#' @export
+setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
+
 #' @rdname spark.isoreg
 #' @export
 setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })

@@ -1315,6 +1328,7 @@ setGeneric("spark.gaussianMixture",
             standardGeneric("spark.gaussianMixture")
           })

+#' write.ml
 #' @rdname write.ml
 #' @export
 setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
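These `setGeneric` calls only declare the S4 generics; the concrete behavior is supplied by the `setMethod` definitions in mllib.R below, dispatched on argument classes. As a toy illustration of the dispatch mechanics (the names here are hypothetical, not part of the patch):

```r
# The generic declares the contract; setMethod attaches an implementation
# for a specific signature of argument classes.
setGeneric("describe", function(object, ...) { standardGeneric("describe") })
setClass("ToyModel", representation(k = "numeric"))
setMethod("describe", signature(object = "ToyModel"),
          function(object) { paste("toy model with k =", object@k) })
describe(new("ToyModel", k = 3))  # dispatches on class(object)
```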
R/pkg/R/mllib.R (166 lines changed)

@@ -39,6 +39,13 @@ setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))
 #' @note NaiveBayesModel since 2.0.0
 setClass("NaiveBayesModel", representation(jobj = "jobj"))

+#' S4 class that represents an LDAModel
+#'
+#' @param jobj a Java object reference to the backing Scala LDAWrapper
+#' @export
+#' @note LDAModel since 2.1.0
+setClass("LDAModel", representation(jobj = "jobj"))
+
 #' S4 class that represents a AFTSurvivalRegressionModel
 #'
 #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
@@ -75,7 +82,7 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
 #' @name write.ml
 #' @export
 #' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
 #' @seealso \link{spark.isoreg}
 #' @seealso \link{read.ml}
 NULL
@@ -315,6 +322,94 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
             return(list(apriori = apriori, tables = tables))
           })

+# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda()
+
+#' @param newData A SparkDataFrame for testing
+#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities
+#'         vectors named "topicDistribution"
+#' @rdname spark.lda
+#' @aliases spark.posterior,LDAModel,SparkDataFrame-method
+#' @export
+#' @note spark.posterior(LDAModel) since 2.1.0
+setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+          })
+
+# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}.
+#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10.
+#' @return \code{summary} returns a list containing
+#'         \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for
+#'               the prior placed on documents distributions over topics \code{theta}}
+#'         \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or
+#'               \code{eta} for the prior placed on topic distributions over terms}
+#'         \item{\code{logLikelihood}}{log likelihood of the entire corpus}
+#'         \item{\code{logPerplexity}}{log perplexity}
+#'         \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model}
+#'         \item{\code{vocabSize}}{number of terms in the corpus}
+#'         \item{\code{topics}}{top 10 terms and their weights of all topics}
+#'         \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file
+#'               used as training set}
+#' @rdname spark.lda
+#' @aliases summary,LDAModel-method
+#' @export
+#' @note summary(LDAModel) since 2.1.0
+setMethod("summary", signature(object = "LDAModel"),
+          function(object, maxTermsPerTopic) {
+            maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic))
+            jobj <- object@jobj
+            docConcentration <- callJMethod(jobj, "docConcentration")
+            topicConcentration <- callJMethod(jobj, "topicConcentration")
+            logLikelihood <- callJMethod(jobj, "logLikelihood")
+            logPerplexity <- callJMethod(jobj, "logPerplexity")
+            isDistributed <- callJMethod(jobj, "isDistributed")
+            vocabSize <- callJMethod(jobj, "vocabSize")
+            topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
+            vocabulary <- callJMethod(jobj, "vocabulary")
+            return(list(docConcentration = unlist(docConcentration),
+                        topicConcentration = topicConcentration,
+                        logLikelihood = logLikelihood, logPerplexity = logPerplexity,
+                        isDistributed = isDistributed, vocabSize = vocabSize,
+                        topics = topics,
+                        vocabulary = unlist(vocabulary)))
+          })
+
+# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log
+#'         perplexity of the training data if missing argument "data".
+#' @rdname spark.lda
+#' @aliases spark.perplexity,LDAModel-method
+#' @export
+#' @note spark.perplexity(LDAModel) since 2.1.0
+setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
+          function(object, data) {
+            return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
+                          callJMethod(object@jobj, "computeLogPerplexity", data@sdf)))
+          })
+
+# Saves the Latent Dirichlet Allocation model to the input path.
+
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @rdname spark.lda
+#' @aliases write.ml,LDAModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(LDAModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "LDAModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
 #' Isotonic Regression Model
 #'
 #' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg().
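Everything in the list returned by the new `summary` method is plain R data except `topics`, which stays a SparkDataFrame. A short consumption sketch, assuming a model fitted as in the example near the top:

```r
stats <- summary(model)       # default maxTermsPerTopic = 10
stats$docConcentration        # effective alpha prior (numeric vector)
stats$topicConcentration      # effective beta/eta prior (single numeric)
showDF(stats$topics)          # top terms and their weights per topic
head(stats$vocabulary)        # NULL when the model was trained on libsvm input
```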
@@ -700,6 +795,8 @@ read.ml <- function(path) {
     return(new("GeneralizedLinearRegressionModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
     return(new("KMeansModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
+    return(new("LDAModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
     return(new("IsotonicRegressionModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
@@ -751,6 +848,71 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
             return(new("AFTSurvivalRegressionModel", jobj = jobj))
           })

+#' Latent Dirichlet Allocation
+#'
+#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call
+#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute
+#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new
+#' data and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' @param data A SparkDataFrame for training
+#' @param features Features column name, default "features". Either libSVM-format column or
+#'        character-format column is valid.
+#' @param k Number of topics, default 10
+#' @param maxIter Maximum iterations, default 20
+#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online"
+#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in
+#'        each iteration of mini-batch gradient descent, in range (0, 1], default 0.05
+#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for
+#'        the prior placed on topic distributions over terms, default -1 to set automatically on the
+#'        Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size
+#'        numeric is accepted.
+#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the
+#'        prior placed on documents distributions over topics (\code{theta}), default -1 to set
+#'        automatically on the Spark side. Use \code{summary} to retrieve the effective
+#'        docConcentration. Only 1-size or \code{k}-size numeric is accepted.
+#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the
+#'        parameter if libSVM-format column is used as the features column.
+#' @param maxVocabSize maximum vocabulary size, default 1 << 18
+#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model
+#' @rdname spark.lda
+#' @aliases spark.lda,SparkDataFrame-method
+#' @seealso topicmodels: \url{https://cran.r-project.org/web/packages/topicmodels/}
+#' @export
+#' @examples
+#' \dontrun{
+#' text <- read.df("path/to/data", source = "libsvm")
+#' model <- spark.lda(data = text, optimizer = "em")
+#'
+#' # get a summary of the model
+#' summary(model)
+#'
+#' # compute posterior probabilities
+#' posterior <- spark.posterior(model, df)
+#' showDF(posterior)
+#'
+#' # compute perplexity
+#' perplexity <- spark.perplexity(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.lda since 2.1.0
+setMethod("spark.lda", signature(data = "SparkDataFrame"),
+          function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"),
+                   subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1,
+                   customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) {
+            optimizer <- match.arg(optimizer)
+            jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features,
+                                as.integer(k), as.integer(maxIter), optimizer,
+                                as.numeric(subsamplingRate), topicConcentration,
+                                as.array(docConcentration), as.array(customizedStopWords),
+                                maxVocabSize)
+            return(new("LDAModel", jobj = jobj))
+          })
+
 # Returns a summary of the AFT survival regression model produced by spark.survreg,
 # similarly to R's summary().
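For orientation: `docConcentration` and `topicConcentration` are the two Dirichlet hyperparameters of the standard smoothed LDA generative model. The notation below is the usual one from the LDA literature, not taken from this patch, and the perplexity identity holds only up to Spark's variational approximation:

```latex
% Smoothed LDA generative model:
%   alpha = docConcentration, eta (often written beta) = topicConcentration.
\theta_d \sim \operatorname{Dirichlet}(\alpha) \qquad
\phi_k   \sim \operatorname{Dirichlet}(\eta)  \qquad
z_{dn}   \sim \operatorname{Multinomial}(\theta_d) \qquad
w_{dn}   \sim \operatorname{Multinomial}(\phi_{z_{dn}})

% Log perplexity, as reported by summary/spark.perplexity, is roughly the
% negated (bound on the) log likelihood per token over N_d-token documents:
\mathrm{logPerplexity} \approx -\frac{\log p(\mathbf{w} \mid \alpha, \eta)}{\sum_d N_d}
```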
@@ -891,4 +1053,4 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
 setMethod("predict", signature(object = "GaussianMixtureModel"),
           function(object, newData) {
             return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
           })
R/pkg/inst/tests/testthat/test_mllib.R

@@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", {
   unlink(modelPath)
 })

+test_that("spark.lda with libsvm", {
+  text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
+  model <- spark.lda(text, optimizer = "em")
+
+  stats <- summary(model, 10)
+  isDistributed <- stats$isDistributed
+  logLikelihood <- stats$logLikelihood
+  logPerplexity <- stats$logPerplexity
+  vocabSize <- stats$vocabSize
+  topics <- stats$topicTopTerms
+  weights <- stats$topicTopTermsWeights
+  vocabulary <- stats$vocabulary
+
+  expect_false(isDistributed)
+  expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+  expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+  expect_equal(vocabSize, 11)
+  expect_true(is.null(vocabulary))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+
+  expect_false(stats2$isDistributed)
+  expect_equal(logLikelihood, stats2$logLikelihood)
+  expect_equal(logPerplexity, stats2$logPerplexity)
+  expect_equal(vocabSize, stats2$vocabSize)
+  expect_equal(vocabulary, stats2$vocabulary)
+
+  unlink(modelPath)
+})
+
+test_that("spark.lda with text input", {
+  text <- read.text("data/mllib/sample_lda_data.txt")
+  model <- spark.lda(text, optimizer = "online", features = "value")
+
+  stats <- summary(model)
+  isDistributed <- stats$isDistributed
+  logLikelihood <- stats$logLikelihood
+  logPerplexity <- stats$logPerplexity
+  vocabSize <- stats$vocabSize
+  topics <- stats$topicTopTerms
+  weights <- stats$topicTopTermsWeights
+  vocabulary <- stats$vocabulary
+
+  expect_false(isDistributed)
+  expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+  expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+  expect_equal(vocabSize, 10)
+  expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+
+  expect_false(stats2$isDistributed)
+  expect_equal(logLikelihood, stats2$logLikelihood)
+  expect_equal(logPerplexity, stats2$logPerplexity)
+  expect_equal(vocabSize, stats2$vocabSize)
+  expect_true(all.equal(vocabulary, stats2$vocabulary))
+
+  unlink(modelPath)
+})
+
+test_that("spark.posterior and spark.perplexity", {
+  text <- read.text("data/mllib/sample_lda_data.txt")
+  model <- spark.lda(text, features = "value", k = 3)
+
+  # Assert perplexities are equal
+  stats <- summary(model)
+  logPerplexity <- spark.perplexity(model, text)
+  expect_equal(logPerplexity, stats$logPerplexity)
+
+  # Assert the sum of every topic distribution is equal to 1
+  posterior <- spark.posterior(model, text)
+  local.posterior <- collect(posterior)$topicDistribution
+  expect_equal(length(local.posterior), sum(unlist(local.posterior)))
+})
+
 sparkR.session.stop()
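The last assertion leans on a small identity worth spelling out: each row's `topicDistribution` is a probability vector summing to 1, so the grand total over all rows equals the row count. A toy standalone illustration with made-up values:

```r
# Two "posterior rows", each a probability vector over 3 topics.
local.posterior <- list(c(0.2, 0.5, 0.3), c(0.1, 0.1, 0.8))
length(local.posterior)       # 2 rows
sum(unlist(local.posterior))  # (0.2+0.5+0.3) + (0.1+0.1+0.8) = 2
```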
mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

@@ -386,6 +386,10 @@ sealed abstract class LDAModel private[ml] (
   @Since("1.6.0")
   protected def getModel: OldLDAModel

+  private[ml] def getEffectiveDocConcentration: Array[Double] = getModel.docConcentration.toArray
+
+  private[ml] def getEffectiveTopicConcentration: Double = getModel.topicConcentration
+
   /**
    * The features for LDA should be a [[Vector]] representing the word counts in a document.
    * The vector should be of length vocabSize, with counts for each term (word).
mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala (new file, 216 lines)

@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamPair
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StringType
+
+
+private[r] class LDAWrapper private (
+    val pipeline: PipelineModel,
+    val logLikelihood: Double,
+    val logPerplexity: Double,
+    val vocabulary: Array[String]) extends MLWritable {
+
+  import LDAWrapper._
+
+  private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+  private val preprocessor: PipelineModel =
+    new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
+
+  def transform(data: Dataset[_]): DataFrame = {
+    val vec2ary = udf { vec: Vector => vec.toArray }
+    val outputCol = lda.getTopicDistributionCol
+    val tempCol = s"${Identifiable.randomUID(outputCol)}"
+    val preprocessed = preprocessor.transform(data)
+    lda.transform(preprocessed, ParamPair(lda.topicDistributionCol, tempCol))
+      .withColumn(outputCol, vec2ary(col(tempCol)))
+      .drop(TOKENIZER_COL, STOPWORDS_REMOVER_COL, COUNT_VECTOR_COL, tempCol)
+  }
+
+  def computeLogPerplexity(data: Dataset[_]): Double = {
+    lda.logPerplexity(preprocessor.transform(data))
+  }
+
+  def topics(maxTermsPerTopic: Int): DataFrame = {
+    val topicIndices: DataFrame = lda.describeTopics(maxTermsPerTopic)
+    if (vocabulary.isEmpty || vocabulary.length < vocabSize) {
+      topicIndices
+    } else {
+      val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) }
+      topicIndices
+        .select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights"))
+    }
+  }
+
+  lazy val isDistributed: Boolean = lda.isDistributed
+  lazy val vocabSize: Int = lda.vocabSize
+  lazy val docConcentration: Array[Double] = lda.getEffectiveDocConcentration
+  lazy val topicConcentration: Double = lda.getEffectiveTopicConcentration
+
+  override def write: MLWriter = new LDAWrapper.LDAWrapperWriter(this)
+}
+
+private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
+
+  val TOKENIZER_COL = s"${Identifiable.randomUID("rawTokens")}"
+  val STOPWORDS_REMOVER_COL = s"${Identifiable.randomUID("tokens")}"
+  val COUNT_VECTOR_COL = s"${Identifiable.randomUID("features")}"
+
+  private def getPreStages(
+      features: String,
+      customizedStopWords: Array[String],
+      maxVocabSize: Int): Array[PipelineStage] = {
+    val tokenizer = new RegexTokenizer()
+      .setInputCol(features)
+      .setOutputCol(TOKENIZER_COL)
+    val stopWordsRemover = new StopWordsRemover()
+      .setInputCol(TOKENIZER_COL)
+      .setOutputCol(STOPWORDS_REMOVER_COL)
+    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+    val countVectorizer = new CountVectorizer()
+      .setVocabSize(maxVocabSize)
+      .setInputCol(STOPWORDS_REMOVER_COL)
+      .setOutputCol(COUNT_VECTOR_COL)
+
+    Array(tokenizer, stopWordsRemover, countVectorizer)
+  }
+
+  def fit(
+      data: DataFrame,
+      features: String,
+      k: Int,
+      maxIter: Int,
+      optimizer: String,
+      subsamplingRate: Double,
+      topicConcentration: Double,
+      docConcentration: Array[Double],
+      customizedStopWords: Array[String],
+      maxVocabSize: Int): LDAWrapper = {
+
+    val lda = new LDA()
+      .setK(k)
+      .setMaxIter(maxIter)
+      .setSubsamplingRate(subsamplingRate)
+
+    val featureSchema = data.schema(features)
+    val stages = featureSchema.dataType match {
+      case d: StringType =>
+        getPreStages(features, customizedStopWords, maxVocabSize) ++
+          Array(lda.setFeaturesCol(COUNT_VECTOR_COL))
+      case d: VectorUDT =>
+        Array(lda.setFeaturesCol(features))
+      case _ =>
+        throw new SparkException(
+          s"Unsupported input features type of ${featureSchema.dataType.typeName}," +
+            s" only String type and Vector type are supported now.")
+    }
+
+    if (topicConcentration != -1) {
+      lda.setTopicConcentration(topicConcentration)
+    } else {
+      // Auto-set topicConcentration
+    }
+
+    if (docConcentration.length == 1) {
+      if (docConcentration.head != -1) {
+        lda.setDocConcentration(docConcentration.head)
+      } else {
+        // Auto-set docConcentration
+      }
+    } else {
+      lda.setDocConcentration(docConcentration)
+    }
+
+    val pipeline = new Pipeline().setStages(stages)
+    val model = pipeline.fit(data)
+
+    val vocabulary: Array[String] = featureSchema.dataType match {
+      case d: StringType =>
+        val countVectorModel = model.stages(2).asInstanceOf[CountVectorizerModel]
+        countVectorModel.vocabulary
+      case _ => Array.empty[String]
+    }
+
+    val ldaModel: LDAModel = model.stages.last.asInstanceOf[LDAModel]
+    val preprocessor: PipelineModel =
+      new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", model.stages.dropRight(1))
+
+    val preprocessedData = preprocessor.transform(data)
+
+    new LDAWrapper(
+      model,
+      ldaModel.logLikelihood(preprocessedData),
+      ldaModel.logPerplexity(preprocessedData),
+      vocabulary)
+  }
+
+  override def read: MLReader[LDAWrapper] = new LDAWrapperReader
+
+  override def load(path: String): LDAWrapper = super.load(path)
+
+  class LDAWrapperWriter(instance: LDAWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("logLikelihood" -> instance.logLikelihood) ~
+        ("logPerplexity" -> instance.logPerplexity) ~
+        ("vocabulary" -> instance.vocabulary.toList)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class LDAWrapperReader extends MLReader[LDAWrapper] {
+
+    override def load(path: String): LDAWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
+      val logPerplexity = (rMetadata \ "logPerplexity").extract[Double]
+      val vocabulary = (rMetadata \ "vocabulary").extract[List[String]].toArray
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new LDAWrapper(pipeline, logLikelihood, logPerplexity, vocabulary)
+    }
+  }
+}
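Note how `LDAWrapper.fit` branches on the schema of the features column: a Vector column is fed to LDA directly, while a String column is first run through RegexTokenizer, StopWordsRemover, and CountVectorizer. From the R side the difference is just which column name is passed. A sketch reusing the sample datasets from the tests above:

```r
# Vector input: the libsvm loader yields a "features" column of count vectors,
# which matches spark.lda's default features = "features".
vecData <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
vecModel <- spark.lda(vecData)

# Text input: read.text yields a "value" string column; the Scala wrapper then
# tokenizes, removes (optionally customized) stop words, and count-vectorizes.
txtData <- read.text("data/mllib/sample_lda_data.txt")
txtModel <- spark.lda(txtData, features = "value",
                      customizedStopWords = c("the", "a"))
```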
mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala

@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
         GeneralizedLinearRegressionWrapper.load(path)
       case "org.apache.spark.ml.r.KMeansWrapper" =>
         KMeansWrapper.load(path)
+      case "org.apache.spark.ml.r.LDAWrapper" =>
+        LDAWrapper.load(path)
       case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
         IsotonicRegressionWrapper.load(path)
       case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>