[SPARK-16446][SPARKR][ML] Gaussian Mixture Model wrapper in SparkR
## What changes were proposed in this pull request? Gaussian Mixture Model wrapper in SparkR, similarly to R's ```mvnormalmixEM```. ## How was this patch tested? Unit test. Author: Yanbo Liang <ybliang8@gmail.com> Closes #14392 from yanboliang/spark-16446.
This commit is contained in:
parent
e3fec51fa1
commit
4d92af310a
|
@ -25,7 +25,8 @@ exportMethods("glm",
|
||||||
"fitted",
|
"fitted",
|
||||||
"spark.naiveBayes",
|
"spark.naiveBayes",
|
||||||
"spark.survreg",
|
"spark.survreg",
|
||||||
"spark.isoreg")
|
"spark.isoreg",
|
||||||
|
"spark.gaussianMixture")
|
||||||
|
|
||||||
# Job group lifecycle management methods
|
# Job group lifecycle management methods
|
||||||
export("setJobGroup",
|
export("setJobGroup",
|
||||||
|
|
|
@ -1308,6 +1308,13 @@ setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spar
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
|
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
|
||||||
|
|
||||||
|
#' @rdname spark.gaussianMixture
|
||||||
|
#' @export
|
||||||
|
setGeneric("spark.gaussianMixture",
|
||||||
|
function(data, formula, ...) {
|
||||||
|
standardGeneric("spark.gaussianMixture")
|
||||||
|
})
|
||||||
|
|
||||||
#' @rdname write.ml
|
#' @rdname write.ml
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
|
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
|
||||||
|
|
139
R/pkg/R/mllib.R
139
R/pkg/R/mllib.R
|
@ -60,6 +60,13 @@ setClass("KMeansModel", representation(jobj = "jobj"))
|
||||||
#' @note IsotonicRegressionModel since 2.1.0
|
#' @note IsotonicRegressionModel since 2.1.0
|
||||||
setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
|
setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
|
||||||
|
|
||||||
|
#' S4 class that represents a GaussianMixtureModel
|
||||||
|
#'
|
||||||
|
#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel
|
||||||
|
#' @export
|
||||||
|
#' @note GaussianMixtureModel since 2.1.0
|
||||||
|
setClass("GaussianMixtureModel", representation(jobj = "jobj"))
|
||||||
|
|
||||||
#' Saves the MLlib model to the input path
|
#' Saves the MLlib model to the input path
|
||||||
#'
|
#'
|
||||||
#' Saves the MLlib model to the input path. For more information, see the specific
|
#' Saves the MLlib model to the input path. For more information, see the specific
|
||||||
|
@ -67,7 +74,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
|
||||||
#' @rdname write.ml
|
#' @rdname write.ml
|
||||||
#' @name write.ml
|
#' @name write.ml
|
||||||
#' @export
|
#' @export
|
||||||
#' @seealso \link{spark.glm}, \link{glm}
|
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
|
||||||
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
|
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
|
||||||
#' @seealso \link{spark.isoreg}
|
#' @seealso \link{spark.isoreg}
|
||||||
#' @seealso \link{read.ml}
|
#' @seealso \link{read.ml}
|
||||||
|
@ -80,7 +87,7 @@ NULL
|
||||||
#' @rdname predict
|
#' @rdname predict
|
||||||
#' @name predict
|
#' @name predict
|
||||||
#' @export
|
#' @export
|
||||||
#' @seealso \link{spark.glm}, \link{glm}
|
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
|
||||||
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
|
#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
|
||||||
#' @seealso \link{spark.isoreg}
|
#' @seealso \link{spark.isoreg}
|
||||||
NULL
|
NULL
|
||||||
|
@ -649,6 +656,25 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
|
||||||
invisible(callJMethod(writer, "save", path))
|
invisible(callJMethod(writer, "save", path))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Save fitted MLlib model to the input path
|
||||||
|
|
||||||
|
#' @param path the directory where the model is saved.
|
||||||
|
#' @param overwrite whether to overwrite the output path if it already exists. Default is FALSE
|
||||||
|
#' which means throw exception if the output path exists.
|
||||||
|
#'
|
||||||
|
#' @aliases write.ml,GaussianMixtureModel,character-method
|
||||||
|
#' @rdname spark.gaussianMixture
|
||||||
|
#' @export
|
||||||
|
#' @note write.ml(GaussianMixtureModel, character) since 2.1.0
|
||||||
|
setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
|
||||||
|
function(object, path, overwrite = FALSE) {
|
||||||
|
writer <- callJMethod(object@jobj, "write")
|
||||||
|
if (overwrite) {
|
||||||
|
writer <- callJMethod(writer, "overwrite")
|
||||||
|
}
|
||||||
|
invisible(callJMethod(writer, "save", path))
|
||||||
|
})
|
||||||
|
|
||||||
#' Load a fitted MLlib model from the input path.
|
#' Load a fitted MLlib model from the input path.
|
||||||
#'
|
#'
|
||||||
#' @param path Path of the model to read.
|
#' @param path Path of the model to read.
|
||||||
|
@ -676,6 +702,8 @@ read.ml <- function(path) {
|
||||||
return(new("KMeansModel", jobj = jobj))
|
return(new("KMeansModel", jobj = jobj))
|
||||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
|
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
|
||||||
return(new("IsotonicRegressionModel", jobj = jobj))
|
return(new("IsotonicRegressionModel", jobj = jobj))
|
||||||
|
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
|
||||||
|
return(new("GaussianMixtureModel", jobj = jobj))
|
||||||
} else {
|
} else {
|
||||||
stop(paste("Unsupported model: ", jobj))
|
stop(paste("Unsupported model: ", jobj))
|
||||||
}
|
}
|
||||||
|
@ -757,3 +785,110 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
|
||||||
function(object, newData) {
|
function(object, newData) {
|
||||||
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
|
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' Multivariate Gaussian Mixture Model (GMM)
|
||||||
|
#'
|
||||||
|
#' Fits a multivariate Gaussian mixture model against a SparkDataFrame, similar to R's
|
||||||
|
#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model,
|
||||||
|
#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml}
|
||||||
|
#' to save/load fitted models.
|
||||||
|
#'
|
||||||
|
#' @param data a SparkDataFrame for training.
|
||||||
|
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
|
||||||
|
#' operators are supported, including '~', '.', ':', '+', and '-'.
|
||||||
|
#' Note that the formula has no response variable (empty left-hand side), e.g. \code{~ V1 + V2}.
|
||||||
|
#' @param k number of independent Gaussians in the mixture model.
|
||||||
|
#' @param maxIter maximum iteration number.
|
||||||
|
#' @param tol the convergence tolerance.
|
||||||
|
#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
|
||||||
|
#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model.
|
||||||
|
#' @rdname spark.gaussianMixture
|
||||||
|
#' @name spark.gaussianMixture
|
||||||
|
#' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools}
|
||||||
|
#' @export
|
||||||
|
#' @examples
|
||||||
|
#' \dontrun{
|
||||||
|
#' sparkR.session()
|
||||||
|
#' library(mvtnorm)
|
||||||
|
#' set.seed(100)
|
||||||
|
#' a <- rmvnorm(4, c(0, 0))
|
||||||
|
#' b <- rmvnorm(6, c(3, 4))
|
||||||
|
#' data <- rbind(a, b)
|
||||||
|
#' df <- createDataFrame(as.data.frame(data))
|
||||||
|
#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2)
|
||||||
|
#' summary(model)
|
||||||
|
#'
|
||||||
|
#' # fitted values on training data
|
||||||
|
#' fitted <- predict(model, df)
|
||||||
|
#' head(select(fitted, "V1", "prediction"))
|
||||||
|
#'
|
||||||
|
#' # save fitted model to input path
|
||||||
|
#' path <- "path/to/model"
|
||||||
|
#' write.ml(model, path)
|
||||||
|
#'
|
||||||
|
#' # can also read back the saved model and print
|
||||||
|
#' savedModel <- read.ml(path)
|
||||||
|
#' summary(savedModel)
|
||||||
|
#' }
|
||||||
|
#' @note spark.gaussianMixture since 2.1.0
|
||||||
|
#' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
|
||||||
|
setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"),
|
||||||
|
function(data, formula, k = 2, maxIter = 100, tol = 0.01) {
|
||||||
|
formula <- paste(deparse(formula), collapse = "")
|
||||||
|
jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf,
|
||||||
|
formula, as.integer(k), as.integer(maxIter), as.numeric(tol))
|
||||||
|
return(new("GaussianMixtureModel", jobj = jobj))
|
||||||
|
})
|
||||||
|
|
||||||
|
# Get the summary of a multivariate gaussian mixture model
|
||||||
|
|
||||||
|
#' @param object a fitted gaussian mixture model.
|
||||||
|
#' @param ... currently not used argument(s) passed to the method.
|
||||||
|
#' @return \code{summary} returns a list with the model's lambda, mu, sigma, posterior and
#'         is.loaded; posterior is NULL when the model was loaded with \code{read.ml}.
|
||||||
|
#' @aliases summary,GaussianMixtureModel-method
|
||||||
|
#' @rdname spark.gaussianMixture
|
||||||
|
#' @export
|
||||||
|
#' @note summary(GaussianMixtureModel) since 2.1.0
|
||||||
|
setMethod("summary", signature(object = "GaussianMixtureModel"),
|
||||||
|
function(object, ...) {
|
||||||
|
jobj <- object@jobj
|
||||||
|
is.loaded <- callJMethod(jobj, "isLoaded")
|
||||||
|
lambda <- unlist(callJMethod(jobj, "lambda"))
|
||||||
|
muList <- callJMethod(jobj, "mu")
|
||||||
|
sigmaList <- callJMethod(jobj, "sigma")
|
||||||
|
k <- callJMethod(jobj, "k")
|
||||||
|
dim <- callJMethod(jobj, "dim")
|
||||||
|
mu <- c()
|
||||||
|
for (i in 1 : k) {
|
||||||
|
start <- (i - 1) * dim + 1
|
||||||
|
end <- i * dim
|
||||||
|
mu[[i]] <- unlist(muList[start : end])
|
||||||
|
}
|
||||||
|
sigma <- c()
|
||||||
|
for (i in 1 : k) {
|
||||||
|
start <- (i - 1) * dim * dim + 1
|
||||||
|
end <- i * dim * dim
|
||||||
|
sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim))
|
||||||
|
}
|
||||||
|
posterior <- if (is.loaded) {
|
||||||
|
NULL
|
||||||
|
} else {
|
||||||
|
dataFrame(callJMethod(jobj, "posterior"))
|
||||||
|
}
|
||||||
|
return(list(lambda = lambda, mu = mu, sigma = sigma,
|
||||||
|
posterior = posterior, is.loaded = is.loaded))
|
||||||
|
})
|
||||||
|
|
||||||
|
# Predicted values based on a gaussian mixture model
|
||||||
|
|
||||||
|
#' @param newData a SparkDataFrame for testing.
|
||||||
|
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
|
||||||
|
#' "prediction".
|
||||||
|
#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method
|
||||||
|
#' @rdname spark.gaussianMixture
|
||||||
|
#' @export
|
||||||
|
#' @note predict(GaussianMixtureModel) since 2.1.0
|
||||||
|
setMethod("predict", signature(object = "GaussianMixtureModel"),
|
||||||
|
function(object, newData) {
|
||||||
|
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
|
||||||
|
})
|
||||||
|
|
|
@ -508,4 +508,66 @@ test_that("spark.isotonicRegression", {
|
||||||
unlink(modelPath)
|
unlink(modelPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("spark.gaussianMixture", {
|
||||||
|
# R code to reproduce the result.
|
||||||
|
# nolint start
|
||||||
|
#' library(mvtnorm)
|
||||||
|
#' set.seed(100)
|
||||||
|
#' a <- rmvnorm(4, c(0, 0))
|
||||||
|
#' b <- rmvnorm(6, c(3, 4))
|
||||||
|
#' data <- rbind(a, b)
|
||||||
|
#' model <- mixtools::mvnormalmixEM(data, k = 2)
|
||||||
|
#' model$lambda
|
||||||
|
#
|
||||||
|
# [1] 0.4 0.6
|
||||||
|
#
|
||||||
|
#' model$mu
|
||||||
|
#
|
||||||
|
# [1] -0.2614822 0.5128697
|
||||||
|
# [1] 2.647284 4.544682
|
||||||
|
#
|
||||||
|
#' model$sigma
|
||||||
|
#
|
||||||
|
# [[1]]
|
||||||
|
# [,1] [,2]
|
||||||
|
# [1,] 0.08427399 0.00548772
|
||||||
|
# [2,] 0.00548772 0.09090715
|
||||||
|
#
|
||||||
|
# [[2]]
|
||||||
|
# [,1] [,2]
|
||||||
|
# [1,] 0.1641373 -0.1673806
|
||||||
|
# [2,] -0.1673806 0.7508951
|
||||||
|
# nolint end
|
||||||
|
data <- list(list(-0.50219235, 0.1315312), list(-0.07891709, 0.8867848),
|
||||||
|
list(0.11697127, 0.3186301), list(-0.58179068, 0.7145327),
|
||||||
|
list(2.17474057, 3.6401379), list(3.08988614, 4.0962745),
|
||||||
|
list(2.79836605, 4.7398405), list(3.12337950, 3.9706833),
|
||||||
|
list(2.61114575, 4.5108563), list(2.08618581, 6.3102968))
|
||||||
|
df <- createDataFrame(data, c("x1", "x2"))
|
||||||
|
model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
|
||||||
|
stats <- summary(model)
|
||||||
|
rLambda <- c(0.4, 0.6)
|
||||||
|
rMu <- c(-0.2614822, 0.5128697, 2.647284, 4.544682)
|
||||||
|
rSigma <- c(0.08427399, 0.00548772, 0.00548772, 0.09090715,
|
||||||
|
0.1641373, -0.1673806, -0.1673806, 0.7508951)
|
||||||
|
expect_equal(stats$lambda, rLambda)
|
||||||
|
expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
|
||||||
|
expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
|
||||||
|
p <- collect(select(predict(model, df), "prediction"))
|
||||||
|
expect_equal(p$prediction, c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1))
|
||||||
|
|
||||||
|
# Test model save/load
|
||||||
|
modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
|
||||||
|
write.ml(model, modelPath)
|
||||||
|
expect_error(write.ml(model, modelPath))
|
||||||
|
write.ml(model, modelPath, overwrite = TRUE)
|
||||||
|
model2 <- read.ml(modelPath)
|
||||||
|
stats2 <- summary(model2)
|
||||||
|
expect_equal(stats$lambda, stats2$lambda)
|
||||||
|
expect_equal(unlist(stats$mu), unlist(stats2$mu))
|
||||||
|
expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
|
||||||
|
|
||||||
|
unlink(modelPath)
|
||||||
|
})
|
||||||
|
|
||||||
sparkR.session.stop()
|
sparkR.session.stop()
|
||||||
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.ml.r
|
||||||
|
|
||||||
|
import org.apache.hadoop.fs.Path
|
||||||
|
import org.json4s._
|
||||||
|
import org.json4s.JsonDSL._
|
||||||
|
import org.json4s.jackson.JsonMethods._
|
||||||
|
|
||||||
|
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||||
|
import org.apache.spark.ml.attribute.AttributeGroup
|
||||||
|
import org.apache.spark.ml.clustering.{GaussianMixture, GaussianMixtureModel}
|
||||||
|
import org.apache.spark.ml.feature.RFormula
|
||||||
|
import org.apache.spark.ml.linalg.Vector
|
||||||
|
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
|
||||||
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
|
private[r] class GaussianMixtureWrapper private (
|
||||||
|
val pipeline: PipelineModel,
|
||||||
|
val dim: Int,
|
||||||
|
val isLoaded: Boolean = false) extends MLWritable {
|
||||||
|
|
||||||
|
private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
|
||||||
|
|
||||||
|
lazy val k: Int = gmm.getK
|
||||||
|
|
||||||
|
lazy val lambda: Array[Double] = gmm.weights
|
||||||
|
|
||||||
|
lazy val mu: Array[Double] = gmm.gaussians.flatMap(_.mean.toArray)
|
||||||
|
|
||||||
|
lazy val sigma: Array[Double] = gmm.gaussians.flatMap(_.cov.toArray)
|
||||||
|
|
||||||
|
lazy val vectorToArray = udf { probability: Vector => probability.toArray }
|
||||||
|
lazy val posterior: DataFrame = gmm.summary.probability
|
||||||
|
.withColumn("posterior", vectorToArray(col(gmm.summary.probabilityCol)))
|
||||||
|
.drop(gmm.summary.probabilityCol)
|
||||||
|
|
||||||
|
def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
|
pipeline.transform(dataset).drop(gmm.getFeaturesCol)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def write: MLWriter = new GaussianMixtureWrapper.GaussianMixtureWrapperWriter(this)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapper] {
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
data: DataFrame,
|
||||||
|
formula: String,
|
||||||
|
k: Int,
|
||||||
|
maxIter: Int,
|
||||||
|
tol: Double): GaussianMixtureWrapper = {
|
||||||
|
|
||||||
|
val rFormulaModel = new RFormula()
|
||||||
|
.setFormula(formula)
|
||||||
|
.setFeaturesCol("features")
|
||||||
|
.fit(data)
|
||||||
|
|
||||||
|
// get feature names from output schema
|
||||||
|
val schema = rFormulaModel.transform(data).schema
|
||||||
|
val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
|
||||||
|
.attributes.get
|
||||||
|
val features = featureAttrs.map(_.name.get)
|
||||||
|
val dim = features.length
|
||||||
|
|
||||||
|
val gm = new GaussianMixture()
|
||||||
|
.setK(k)
|
||||||
|
.setMaxIter(maxIter)
|
||||||
|
.setTol(tol)
|
||||||
|
|
||||||
|
val pipeline = new Pipeline()
|
||||||
|
.setStages(Array(rFormulaModel, gm))
|
||||||
|
.fit(data)
|
||||||
|
|
||||||
|
new GaussianMixtureWrapper(pipeline, dim)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader
|
||||||
|
|
||||||
|
override def load(path: String): GaussianMixtureWrapper = super.load(path)
|
||||||
|
|
||||||
|
class GaussianMixtureWrapperWriter(instance: GaussianMixtureWrapper) extends MLWriter {
|
||||||
|
|
||||||
|
override protected def saveImpl(path: String): Unit = {
|
||||||
|
val rMetadataPath = new Path(path, "rMetadata").toString
|
||||||
|
val pipelinePath = new Path(path, "pipeline").toString
|
||||||
|
|
||||||
|
val rMetadata = ("class" -> instance.getClass.getName) ~
|
||||||
|
("dim" -> instance.dim)
|
||||||
|
val rMetadataJson: String = compact(render(rMetadata))
|
||||||
|
|
||||||
|
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
|
||||||
|
instance.pipeline.save(pipelinePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class GaussianMixtureWrapperReader extends MLReader[GaussianMixtureWrapper] {
|
||||||
|
|
||||||
|
override def load(path: String): GaussianMixtureWrapper = {
|
||||||
|
implicit val format = DefaultFormats
|
||||||
|
val rMetadataPath = new Path(path, "rMetadata").toString
|
||||||
|
val pipelinePath = new Path(path, "pipeline").toString
|
||||||
|
val pipeline = PipelineModel.load(pipelinePath)
|
||||||
|
|
||||||
|
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
|
||||||
|
val rMetadata = parse(rMetadataStr)
|
||||||
|
val dim = (rMetadata \ "dim").extract[Int]
|
||||||
|
new GaussianMixtureWrapper(pipeline, dim, isLoaded = true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -46,6 +46,8 @@ private[r] object RWrappers extends MLReader[Object] {
|
||||||
KMeansWrapper.load(path)
|
KMeansWrapper.load(path)
|
||||||
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
|
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
|
||||||
IsotonicRegressionWrapper.load(path)
|
IsotonicRegressionWrapper.load(path)
|
||||||
|
case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
|
||||||
|
GaussianMixtureWrapper.load(path)
|
||||||
case _ =>
|
case _ =>
|
||||||
throw new SparkException(s"SparkR read.ml does not support load $className")
|
throw new SparkException(s"SparkR read.ml does not support load $className")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue