From 87ac84d43729c54be100bb9ad7dc6e8fa14b8805 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Fri, 29 Apr 2016 09:42:54 -0700
Subject: [PATCH] [SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)

SparkR ```glm``` and ```kmeans``` model persistence.

Unit tests.

Author: Yanbo Liang
Author: Gayathri Murali

Closes #12778 from yanboliang/spark-14311.

Closes #12680
Closes #12683
---
 R/pkg/R/mllib.R                                    |  98 ++++++++--
 R/pkg/inst/tests/testthat/test_mllib.R             |  41 ++++
 .../ml/r/AFTSurvivalRegressionWrapper.scala        |   1 -
 .../GeneralizedLinearRegressionWrapper.scala       | 181 +++++++++++++-----
 .../org/apache/spark/ml/r/KMeansWrapper.scala      |  65 ++++++-
 .../apache/spark/ml/r/NaiveBayesWrapper.scala      |   1 -
 .../org/apache/spark/ml/r/RWrappers.scala          |   4 +
 7 files changed, 315 insertions(+), 76 deletions(-)

diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 480301192d..c2326ea116 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -99,9 +99,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
           function(object, ...) {
             jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "rFeatures")
             coefficients <- callJMethod(jobj, "rCoefficients")
-            deviance.resid <- callJMethod(jobj, "rDevianceResiduals")
             dispersion <- callJMethod(jobj, "rDispersion")
             null.deviance <- callJMethod(jobj, "rNullDeviance")
             deviance <- callJMethod(jobj, "rDeviance")
@@ -110,15 +110,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
             aic <- callJMethod(jobj, "rAic")
             iter <- callJMethod(jobj, "rNumIterations")
             family <- callJMethod(jobj, "rFamily")
-
-            deviance.resid <- dataFrame(deviance.resid)
+            deviance.resid <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "rDevianceResiduals"))
+            }
             coefficients <- matrix(coefficients, ncol = 4)
             colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
             rownames(coefficients) <- unlist(features)
             ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
                         dispersion = dispersion, null.deviance = null.deviance,
                         deviance = deviance, df.null = df.null, df.residual = df.residual,
-                        aic = aic, iter = iter, family = family)
+                        aic = aic, iter = iter, family = family, is.loaded = is.loaded)
             class(ans) <- "summary.GeneralizedLinearRegressionModel"
             return(ans)
           })
@@ -129,12 +132,16 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
 #' @name print.summary.GeneralizedLinearRegressionModel
 #' @export
 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
-  x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
+  if (x$is.loaded) {
+    cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n")
+  } else {
+    x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
     c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
-  x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
-  cat("\nDeviance Residuals: \n")
-  cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
-  print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+    x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
+    cat("\nDeviance Residuals: \n")
+    cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
+    print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+  }
 
   cat("\nCoefficients:\n")
   print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)
@@ -246,6 +253,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 #' Get fitted result from a k-means model
 #'
 #' Get fitted result from a k-means model, similarly to R's fitted().
+#' Note: A saved-loaded model does not support this method.
 #'
 #' @param object A fitted k-means model
 #' @return SparkDataFrame containing fitted values
@@ -260,7 +268,13 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 setMethod("fitted", signature(object = "KMeansModel"),
           function(object, method = c("centers", "classes"), ...) {
             method <- match.arg(method)
-            return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+            jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
+            if (is.loaded) {
+              stop(paste("Saved-loaded k-means model does not support 'fitted' method"))
+            } else {
+              return(dataFrame(callJMethod(jobj, "fitted", method)))
+            }
           })
 
 #' Get the summary of a k-means model
@@ -280,15 +294,21 @@ setMethod("fitted", signature(object = "KMeansModel"),
 setMethod("summary", signature(object = "KMeansModel"),
           function(object, ...) {
             jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "features")
             coefficients <- callJMethod(jobj, "coefficients")
-            cluster <- callJMethod(jobj, "cluster")
             k <- callJMethod(jobj, "k")
             size <- callJMethod(jobj, "size")
             coefficients <- t(matrix(coefficients, ncol = k))
             colnames(coefficients) <- unlist(features)
             rownames(coefficients) <- 1:k
-            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+            cluster <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "cluster"))
+            }
+            return(list(coefficients = coefficients, size = size,
+                        cluster = cluster, is.loaded = is.loaded))
           })
 
 #' Make predictions from a k-means model
@@ -389,6 +409,56 @@ setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "ch
   invisible(callJMethod(writer, "save", path))
 })
 
+#' Save the generalized linear model to the input path.
+#'
+#' @param object A fitted generalized linear model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
+#' Save the k-means model to the input path.
+#'
+#' @param object A fitted k-means model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(x, centers = 2, algorithm="random")
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "KMeansModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
 #' Load a fitted MLlib model from the input path.
 #'
 #' @param path Path of the model to read.
@@ -408,6 +478,10 @@ ml.load <- function(path) {
     return(new("NaiveBayesModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
     return(new("AFTSurvivalRegressionModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) {
+    return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
+    return(new("KMeansModel", jobj = jobj))
   } else {
     stop(paste("Unsupported model: ", jobj))
   }
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 954abb00d4..6a822be121 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -126,6 +126,33 @@ test_that("glm summary", {
   expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
 })
 
+test_that("glm save/load", {
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
+  s <- summary(m)
+
+  modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+  ml.save(m, modelPath)
+  expect_error(ml.save(m, modelPath))
+  ml.save(m, modelPath, overwrite = TRUE)
+  m2 <- ml.load(modelPath)
+  s2 <- summary(m2)
+
+  expect_equal(s$coefficients, s2$coefficients)
+  expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+  expect_equal(s$dispersion, s2$dispersion)
+  expect_equal(s$null.deviance, s2$null.deviance)
+  expect_equal(s$deviance, s2$deviance)
+  expect_equal(s$df.null, s2$df.null)
+  expect_equal(s$df.residual, s2$df.residual)
+  expect_equal(s$aic, s2$aic)
+  expect_equal(s$iter, s2$iter)
+  expect_true(!s$is.loaded)
+  expect_true(s2$is.loaded)
+
+  unlink(modelPath)
+})
+
 test_that("kmeans", {
   newIris <- iris
   newIris$Species <- NULL
@@ -150,6 +177,20 @@ test_that("kmeans", {
   summary.model <- summary(model)
   cluster <- summary.model$cluster
   expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
+  ml.save(model, modelPath)
+  expect_error(ml.save(model, modelPath))
+  ml.save(model, modelPath, overwrite = TRUE)
+  model2 <- ml.load(modelPath)
+  summary2 <- summary(model2)
+  expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+  expect_equal(summary.model$coefficients, summary2$coefficients)
+  expect_true(!summary.model$is.loaded)
+  expect_true(summary2$is.loaded)
+
+  unlink(modelPath)
 })
 
 test_that("naiveBayes", {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index a442469e4d..5462f80d69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index f66323e36c..9618a3423e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -17,65 +17,34 @@
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression._
+import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 
 private[r] class GeneralizedLinearRegressionWrapper private (
-    pipeline: PipelineModel,
-    val features: Array[String]) {
+    val pipeline: PipelineModel,
+    val rFeatures: Array[String],
+    val rCoefficients: Array[Double],
+    val rDispersion: Double,
+    val rNullDeviance: Double,
+    val rDeviance: Double,
+    val rResidualDegreeOfFreedomNull: Long,
+    val rResidualDegreeOfFreedom: Long,
+    val rAic: Double,
+    val rNumIterations: Int,
+    val isLoaded: Boolean = false) extends MLWritable {
 
   private val glm: GeneralizedLinearRegressionModel =
     pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
 
-  lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
-    Array("(Intercept)") ++ features
-  } else {
-    features
-  }
-
-  lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
-    Array(glm.intercept) ++ glm.coefficients.toArray ++
-      rCoefficientStandardErrors ++ rTValues ++ rPValues
-  } else {
-    glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
-  }
-
-  private lazy val rCoefficientStandardErrors = if (glm.getFitIntercept) {
-    Array(glm.summary.coefficientStandardErrors.last) ++
-      glm.summary.coefficientStandardErrors.dropRight(1)
-  } else {
-    glm.summary.coefficientStandardErrors
-  }
-
-  private lazy val rTValues = if (glm.getFitIntercept) {
-    Array(glm.summary.tValues.last) ++ glm.summary.tValues.dropRight(1)
-  } else {
-    glm.summary.tValues
-  }
-
-  private lazy val rPValues = if (glm.getFitIntercept) {
-    Array(glm.summary.pValues.last) ++ glm.summary.pValues.dropRight(1)
-  } else {
-    glm.summary.pValues
-  }
-
-  lazy val rDispersion: Double = glm.summary.dispersion
-
-  lazy val rNullDeviance: Double = glm.summary.nullDeviance
-
-  lazy val rDeviance: Double = glm.summary.deviance
-
-  lazy val rResidualDegreeOfFreedomNull: Long = glm.summary.residualDegreeOfFreedomNull
-
-  lazy val rResidualDegreeOfFreedom: Long = glm.summary.residualDegreeOfFreedom
-
-  lazy val rAic: Double = glm.summary.aic
-
-  lazy val rNumIterations: Int = glm.summary.numIterations
-
   lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()
 
   lazy val rFamily: String = glm.getFamily
@@ -85,9 +54,13 @@ private[r] class GeneralizedLinearRegressionWrapper private (
   def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset).drop(glm.getFeaturesCol)
   }
+
+  override def write: MLWriter =
+    new GeneralizedLinearRegressionWrapper.GeneralizedLinearRegressionWrapperWriter(this)
 }
 
-private[r] object GeneralizedLinearRegressionWrapper {
+private[r] object GeneralizedLinearRegressionWrapper
+  extends MLReadable[GeneralizedLinearRegressionWrapper] {
 
   def fit(
       formula: String,
@@ -105,15 +78,119 @@ private[r] object GeneralizedLinearRegressionWrapper {
       .attributes.get
     val features = featureAttrs.map(_.name.get)
     // assemble and fit the pipeline
-    val glm = new GeneralizedLinearRegression()
+    val glr = new GeneralizedLinearRegression()
       .setFamily(family)
       .setLink(link)
       .setFitIntercept(rFormula.hasIntercept)
      .setTol(epsilon)
       .setMaxIter(maxit)
     val pipeline = new Pipeline()
-      .setStages(Array(rFormulaModel, glm))
+      .setStages(Array(rFormulaModel, glr))
       .fit(data)
-    new GeneralizedLinearRegressionWrapper(pipeline, features)
+
+    val glm: GeneralizedLinearRegressionModel =
+      pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+    val summary = glm.summary
+
+    val rFeatures: Array[String] = if (glm.getFitIntercept) {
+      Array("(Intercept)") ++ features
+    } else {
+      features
+    }
+
+    val rCoefficientStandardErrors = if (glm.getFitIntercept) {
+      Array(summary.coefficientStandardErrors.last) ++
+        summary.coefficientStandardErrors.dropRight(1)
+    } else {
+      summary.coefficientStandardErrors
+    }
+
+    val rTValues = if (glm.getFitIntercept) {
+      Array(summary.tValues.last) ++ summary.tValues.dropRight(1)
+    } else {
+      summary.tValues
+    }
+
+    val rPValues = if (glm.getFitIntercept) {
+      Array(summary.pValues.last) ++ summary.pValues.dropRight(1)
+    } else {
+      summary.pValues
+    }
+
+    val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
+      Array(glm.intercept) ++ glm.coefficients.toArray ++
+        rCoefficientStandardErrors ++ rTValues ++ rPValues
+    } else {
+      glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
+    }
+
+    val rDispersion: Double = summary.dispersion
+    val rNullDeviance: Double = summary.nullDeviance
+    val rDeviance: Double = summary.deviance
+    val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull
+    val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom
+    val rAic: Double = summary.aic
+    val rNumIterations: Int = summary.numIterations
+
+    new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
+      rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
+      rAic, rNumIterations)
+  }
+
+  override def read: MLReader[GeneralizedLinearRegressionWrapper] =
+    new GeneralizedLinearRegressionWrapperReader
+
+  override def load(path: String): GeneralizedLinearRegressionWrapper = super.load(path)
+
+  class GeneralizedLinearRegressionWrapperWriter(instance: GeneralizedLinearRegressionWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("rFeatures" -> instance.rFeatures.toSeq) ~
+        ("rCoefficients" -> instance.rCoefficients.toSeq) ~
+        ("rDispersion" -> instance.rDispersion) ~
+        ("rNullDeviance" -> instance.rNullDeviance) ~
+        ("rDeviance" -> instance.rDeviance) ~
+        ("rResidualDegreeOfFreedomNull" -> instance.rResidualDegreeOfFreedomNull) ~
+        ("rResidualDegreeOfFreedom" -> instance.rResidualDegreeOfFreedom) ~
+        ("rAic" -> instance.rAic) ~
+        ("rNumIterations" -> instance.rNumIterations)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class GeneralizedLinearRegressionWrapperReader
+    extends MLReader[GeneralizedLinearRegressionWrapper] {
+
+    override def load(path: String): GeneralizedLinearRegressionWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val rFeatures = (rMetadata \ "rFeatures").extract[Array[String]]
+      val rCoefficients = (rMetadata \ "rCoefficients").extract[Array[Double]]
+      val rDispersion = (rMetadata \ "rDispersion").extract[Double]
+      val rNullDeviance = (rMetadata \ "rNullDeviance").extract[Double]
+      val rDeviance = (rMetadata \ "rDeviance").extract[Double]
+      val rResidualDegreeOfFreedomNull = (rMetadata \ "rResidualDegreeOfFreedomNull").extract[Long]
+      val rResidualDegreeOfFreedom = (rMetadata \ "rResidualDegreeOfFreedom").extract[Long]
+      val rAic = (rMetadata \ "rAic").extract[Double]
+      val rNumIterations = (rMetadata \ "rNumIterations").extract[Int]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
+        rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
+        rAic, rNumIterations, isLoaded = true)
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index 9e2b81ee20..f67760d3ca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -17,28 +17,30 @@
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
 import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class KMeansWrapper private (
-    pipeline: PipelineModel) {
+    val pipeline: PipelineModel,
+    val features: Array[String],
+    val size: Array[Long],
+    val isLoaded: Boolean = false) extends MLWritable {
 
   private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
 
   lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
 
-  private lazy val attrs = AttributeGroup.fromStructField(
-    kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
-
-  lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
-
   lazy val k: Int = kMeansModel.getK
 
-  lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
-
   lazy val cluster: DataFrame = kMeansModel.summary.cluster
 
   def fitted(method: String): DataFrame = {
@@ -56,9 +58,10 @@ private[r] class KMeansWrapper private (
     pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
   }
 
+  override def write: MLWriter = new KMeansWrapper.KMeansWrapperWriter(this)
 }
 
-private[r] object KMeansWrapper {
+private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
 
   def fit(
       data: DataFrame,
@@ -80,6 +83,48 @@ private[r] object KMeansWrapper {
       .setStages(Array(assembler, kMeans))
       .fit(data)
 
-    new KMeansWrapper(pipeline)
+    val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+    val attrs = AttributeGroup.fromStructField(
+      kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+    val features: Array[String] = attrs.attributes.get.map(_.name.get)
+    val size: Array[Long] = kMeansModel.summary.clusterSizes
+
+    new KMeansWrapper(pipeline, features, size)
+  }
+
+  override def read: MLReader[KMeansWrapper] = new KMeansWrapperReader
+
+  override def load(path: String): KMeansWrapper = super.load(path)
+
+  class KMeansWrapperWriter(instance: KMeansWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq) ~
+        ("size" -> instance.size.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class KMeansWrapperReader extends MLReader[KMeansWrapper] {
+
+    override def load(path: String): KMeansWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+      val size = (rMetadata \ "size").extract[Array[Long]]
+      new KMeansWrapper(pipeline, features, size, isLoaded = true)
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 27c7e72881..28925c79da 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 06baedf2a2..9c0757941e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -40,6 +40,10 @@ private[r] object RWrappers extends MLReader[Object] {
       case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
       case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
        AFTSurvivalRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
+        GeneralizedLinearRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.KMeansWrapper" =>
+        KMeansWrapper.load(path)
       case _ =>
        throw new SparkException(s"SparkR ml.load does not support load $className")
     }
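
A minimal usage sketch of the ml.save/ml.load API introduced by this patch (assumes a SparkR session with sqlContext in scope, as in the unit tests above; the model path is illustrative):

    training <- suppressWarnings(createDataFrame(sqlContext, iris))
    model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)

    modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
    ml.save(model, modelPath)                    # errors if the path already exists
    ml.save(model, modelPath, overwrite = TRUE)  # overwrite an existing saved model

    model2 <- ml.load(modelPath)   # returns a GeneralizedLinearRegressionModel
    summary(model2)                # deviance residuals are not printed for a loaded model
    unlink(modelPath)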