[SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)

Implements model persistence (```ml.save``` / ```ml.load```) for SparkR ```glm``` and ```kmeans``` models.

Unit tests are included.

Author: Yanbo Liang <ybliang8@gmail.com>
Author: Gayathri Murali <gayathri.m.softie@gmail.com>

Closes #12778 from yanboliang/spark-14311.
Closes #12680
Closes #12683
Yanbo Liang 2016-04-29 09:42:54 -07:00 committed by Xiangrui Meng
parent a7d0fedc94
commit 87ac84d437
7 changed files with 315 additions and 76 deletions
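
With this change a fitted SparkR ```glm``` or ```kmeans``` model can be written out and restored with ```ml.save``` / ```ml.load```. A minimal round-trip sketch, mirroring the unit tests added below (it assumes a running SparkR session with a ```sqlContext```):

```r
training <- suppressWarnings(createDataFrame(sqlContext, iris))
m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)

modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
ml.save(m, modelPath)                    # writes the fitted pipeline plus R metadata
# ml.save(m, modelPath)                  # would throw: the path already exists
ml.save(m, modelPath, overwrite = TRUE)  # replacing an existing path must be explicit

m2 <- ml.load(modelPath)                 # returns a GeneralizedLinearRegressionModel
summary(m2)$is.loaded                    # TRUE for a saved-and-loaded model
unlink(modelPath)
```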


@@ -99,9 +99,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
function(object, ...) {
jobj <- object@jobj
is.loaded <- callJMethod(jobj, "isLoaded")
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
deviance.resid <- callJMethod(jobj, "rDevianceResiduals")
dispersion <- callJMethod(jobj, "rDispersion")
null.deviance <- callJMethod(jobj, "rNullDeviance")
deviance <- callJMethod(jobj, "rDeviance")
@@ -110,15 +110,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
aic <- callJMethod(jobj, "rAic")
iter <- callJMethod(jobj, "rNumIterations")
family <- callJMethod(jobj, "rFamily")
deviance.resid <- dataFrame(deviance.resid)
deviance.resid <- if (is.loaded) {
NULL
} else {
dataFrame(callJMethod(jobj, "rDevianceResiduals"))
}
coefficients <- matrix(coefficients, ncol = 4)
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
dispersion = dispersion, null.deviance = null.deviance,
deviance = deviance, df.null = df.null, df.residual = df.residual,
aic = aic, iter = iter, family = family)
aic = aic, iter = iter, family = family, is.loaded = is.loaded)
class(ans) <- "summary.GeneralizedLinearRegressionModel"
return(ans)
})
@@ -129,12 +132,16 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
#' @name print.summary.GeneralizedLinearRegressionModel
#' @export
print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
if (x$is.loaded) {
cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n")
} else {
x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
cat("\nDeviance Residuals: \n")
cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
cat("\nDeviance Residuals: \n")
cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
}
cat("\nCoefficients:\n")
print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)
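
The effect of the new ```is.loaded``` flag: a model restored with ```ml.load``` still carries its R-facing statistics (coefficients, dispersion, deviance, AIC, ...) from the saved metadata, but not the deviance residuals, so ```summary``` returns ```NULL``` for them and ```print``` emits a note instead. A sketch of what callers can expect (```m2``` is a loaded model as in the example above):

```r
s2 <- summary(m2)           # summary of a saved-and-loaded GLM
s2$is.loaded                # TRUE
is.null(s2$deviance.resid)  # TRUE: deviance residuals are not persisted
print(s2)                   # prints the note instead of the "Deviance Residuals" block
```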
@@ -246,6 +253,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
#' Get fitted result from a k-means model
#'
#' Get fitted result from a k-means model, similarly to R's fitted().
#' Note: A saved-loaded model does not support this method.
#'
#' @param object A fitted k-means model
#' @return SparkDataFrame containing fitted values
@@ -260,7 +268,13 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
setMethod("fitted", signature(object = "KMeansModel"),
function(object, method = c("centers", "classes"), ...) {
method <- match.arg(method)
return(dataFrame(callJMethod(object@jobj, "fitted", method)))
jobj <- object@jobj
is.loaded <- callJMethod(jobj, "isLoaded")
if (is.loaded) {
stop(paste("Saved-loaded k-means model does not support 'fitted' method"))
} else {
return(dataFrame(callJMethod(jobj, "fitted", method)))
}
})
#' Get the summary of a k-means model
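
Fitted values are likewise not persisted, so ```fitted``` on a saved-and-loaded k-means model is rejected with an explicit error. A sketch (```kmeansModelPath``` is a hypothetical path written earlier with ```ml.save```):

```r
km2 <- ml.load(kmeansModelPath)  # kmeansModelPath: hypothetical path from a prior ml.save
fitted(km2, method = "centers")
# Error: Saved-loaded k-means model does not support 'fitted' method
```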
@@ -280,15 +294,21 @@ setMethod("fitted", signature(object = "KMeansModel"),
setMethod("summary", signature(object = "KMeansModel"),
function(object, ...) {
jobj <- object@jobj
is.loaded <- callJMethod(jobj, "isLoaded")
features <- callJMethod(jobj, "features")
coefficients <- callJMethod(jobj, "coefficients")
cluster <- callJMethod(jobj, "cluster")
k <- callJMethod(jobj, "k")
size <- callJMethod(jobj, "size")
coefficients <- t(matrix(coefficients, ncol = k))
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
cluster <- if (is.loaded) {
NULL
} else {
dataFrame(callJMethod(jobj, "cluster"))
}
return(list(coefficients = coefficients, size = size,
cluster = cluster, is.loaded = is.loaded))
})
#' Make predictions from a k-means model
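
Similarly, ```summary``` on a loaded k-means model still reports the cluster centers and sizes (restored from the saved metadata) but returns ```NULL``` for the per-row cluster assignments. Continuing the hypothetical ```km2``` above:

```r
s <- summary(km2)
s$is.loaded         # TRUE
s$coefficients      # cluster centers: k rows, one column per feature
unlist(s$size)      # cluster sizes restored from metadata
is.null(s$cluster)  # TRUE: per-row assignments are not persisted
```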
@@ -389,6 +409,56 @@ setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "ch
invisible(callJMethod(writer, "save", path))
})
#' Save the generalized linear model to the input path.
#'
#' @param object A fitted generalized linear model
#' @param path The directory where the model is saved
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname ml.save
#' @name ml.save
#' @export
#' @examples
#' \dontrun{
#' model <- glm(y ~ x, trainingData)
#' path <- "path/to/model"
#' ml.save(model, path)
#' }
setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
writer <- callJMethod(object@jobj, "write")
if (overwrite) {
writer <- callJMethod(writer, "overwrite")
}
invisible(callJMethod(writer, "save", path))
})
#' Save the k-means model to the input path.
#'
#' @param object A fitted k-means model
#' @param path The directory where the model is saved
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname ml.save
#' @name ml.save
#' @export
#' @examples
#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
#' path <- "path/to/model"
#' ml.save(model, path)
#' }
setMethod("ml.save", signature(object = "KMeansModel", path = "character"),
function(object, path, overwrite = FALSE) {
writer <- callJMethod(object@jobj, "write")
if (overwrite) {
writer <- callJMethod(writer, "overwrite")
}
invisible(callJMethod(writer, "save", path))
})
#' Load a fitted MLlib model from the input path.
#'
#' @param path Path of the model to read.
@@ -408,6 +478,10 @@ ml.load <- function(path) {
return(new("NaiveBayesModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
return(new("AFTSurvivalRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) {
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
return(new("KMeansModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
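
```ml.load``` inspects the JVM wrapper class and wraps it in the matching S4 model class, so a loaded model works with the same generics (e.g. ```summary```) as a freshly fitted one. A quick check (both paths are placeholders for directories written by ```ml.save```):

```r
class(ml.load(glmModelPath))     # "GeneralizedLinearRegressionModel"
class(ml.load(kmeansModelPath))  # "KMeansModel"
```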


@@ -126,6 +126,33 @@ test_that("glm summary", {
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})
test_that("glm save/load", {
training <- suppressWarnings(createDataFrame(sqlContext, iris))
m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
s <- summary(m)
modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
ml.save(m, modelPath)
expect_error(ml.save(m, modelPath))
ml.save(m, modelPath, overwrite = TRUE)
m2 <- ml.load(modelPath)
s2 <- summary(m2)
expect_equal(s$coefficients, s2$coefficients)
expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
expect_equal(s$dispersion, s2$dispersion)
expect_equal(s$null.deviance, s2$null.deviance)
expect_equal(s$deviance, s2$deviance)
expect_equal(s$df.null, s2$df.null)
expect_equal(s$df.residual, s2$df.residual)
expect_equal(s$aic, s2$aic)
expect_equal(s$iter, s2$iter)
expect_true(!s$is.loaded)
expect_true(s2$is.loaded)
unlink(modelPath)
})
test_that("kmeans", {
newIris <- iris
newIris$Species <- NULL
@@ -150,6 +177,20 @@ test_that("kmeans", {
summary.model <- summary(model)
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
# Test model save/load
modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
ml.save(model, modelPath)
expect_error(ml.save(model, modelPath))
ml.save(model, modelPath, overwrite = TRUE)
model2 <- ml.load(modelPath)
summary2 <- summary(model2)
expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
expect_equal(summary.model$coefficients, summary2$coefficients)
expect_true(!summary.model$is.loaded)
expect_true(summary2$is.loaded)
unlink(modelPath)
})
test_that("naiveBayes", {


@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._


@@ -17,65 +17,34 @@
package org.apache.spark.ml.r
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
private[r] class GeneralizedLinearRegressionWrapper private (
pipeline: PipelineModel,
val features: Array[String]) {
val pipeline: PipelineModel,
val rFeatures: Array[String],
val rCoefficients: Array[Double],
val rDispersion: Double,
val rNullDeviance: Double,
val rDeviance: Double,
val rResidualDegreeOfFreedomNull: Long,
val rResidualDegreeOfFreedom: Long,
val rAic: Double,
val rNumIterations: Int,
val isLoaded: Boolean = false) extends MLWritable {
private val glm: GeneralizedLinearRegressionModel =
pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
Array("(Intercept)") ++ features
} else {
features
}
lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
Array(glm.intercept) ++ glm.coefficients.toArray ++
rCoefficientStandardErrors ++ rTValues ++ rPValues
} else {
glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
}
private lazy val rCoefficientStandardErrors = if (glm.getFitIntercept) {
Array(glm.summary.coefficientStandardErrors.last) ++
glm.summary.coefficientStandardErrors.dropRight(1)
} else {
glm.summary.coefficientStandardErrors
}
private lazy val rTValues = if (glm.getFitIntercept) {
Array(glm.summary.tValues.last) ++ glm.summary.tValues.dropRight(1)
} else {
glm.summary.tValues
}
private lazy val rPValues = if (glm.getFitIntercept) {
Array(glm.summary.pValues.last) ++ glm.summary.pValues.dropRight(1)
} else {
glm.summary.pValues
}
lazy val rDispersion: Double = glm.summary.dispersion
lazy val rNullDeviance: Double = glm.summary.nullDeviance
lazy val rDeviance: Double = glm.summary.deviance
lazy val rResidualDegreeOfFreedomNull: Long = glm.summary.residualDegreeOfFreedomNull
lazy val rResidualDegreeOfFreedom: Long = glm.summary.residualDegreeOfFreedom
lazy val rAic: Double = glm.summary.aic
lazy val rNumIterations: Int = glm.summary.numIterations
lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()
lazy val rFamily: String = glm.getFamily
@@ -85,9 +54,13 @@ private[r] class GeneralizedLinearRegressionWrapper private (
def transform(dataset: Dataset[_]): DataFrame = {
pipeline.transform(dataset).drop(glm.getFeaturesCol)
}
override def write: MLWriter =
new GeneralizedLinearRegressionWrapper.GeneralizedLinearRegressionWrapperWriter(this)
}
private[r] object GeneralizedLinearRegressionWrapper {
private[r] object GeneralizedLinearRegressionWrapper
extends MLReadable[GeneralizedLinearRegressionWrapper] {
def fit(
formula: String,
@@ -105,15 +78,119 @@ private[r] object GeneralizedLinearRegressionWrapper {
.attributes.get
val features = featureAttrs.map(_.name.get)
// assemble and fit the pipeline
val glm = new GeneralizedLinearRegression()
val glr = new GeneralizedLinearRegression()
.setFamily(family)
.setLink(link)
.setFitIntercept(rFormula.hasIntercept)
.setTol(epsilon)
.setMaxIter(maxit)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glm))
.setStages(Array(rFormulaModel, glr))
.fit(data)
new GeneralizedLinearRegressionWrapper(pipeline, features)
val glm: GeneralizedLinearRegressionModel =
pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
val summary = glm.summary
val rFeatures: Array[String] = if (glm.getFitIntercept) {
Array("(Intercept)") ++ features
} else {
features
}
val rCoefficientStandardErrors = if (glm.getFitIntercept) {
Array(summary.coefficientStandardErrors.last) ++
summary.coefficientStandardErrors.dropRight(1)
} else {
summary.coefficientStandardErrors
}
val rTValues = if (glm.getFitIntercept) {
Array(summary.tValues.last) ++ summary.tValues.dropRight(1)
} else {
summary.tValues
}
val rPValues = if (glm.getFitIntercept) {
Array(summary.pValues.last) ++ summary.pValues.dropRight(1)
} else {
summary.pValues
}
val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
Array(glm.intercept) ++ glm.coefficients.toArray ++
rCoefficientStandardErrors ++ rTValues ++ rPValues
} else {
glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
}
val rDispersion: Double = summary.dispersion
val rNullDeviance: Double = summary.nullDeviance
val rDeviance: Double = summary.deviance
val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull
val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom
val rAic: Double = summary.aic
val rNumIterations: Int = summary.numIterations
new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
rAic, rNumIterations)
}
override def read: MLReader[GeneralizedLinearRegressionWrapper] =
new GeneralizedLinearRegressionWrapperReader
override def load(path: String): GeneralizedLinearRegressionWrapper = super.load(path)
class GeneralizedLinearRegressionWrapperWriter(instance: GeneralizedLinearRegressionWrapper)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val rMetadata = ("class" -> instance.getClass.getName) ~
("rFeatures" -> instance.rFeatures.toSeq) ~
("rCoefficients" -> instance.rCoefficients.toSeq) ~
("rDispersion" -> instance.rDispersion) ~
("rNullDeviance" -> instance.rNullDeviance) ~
("rDeviance" -> instance.rDeviance) ~
("rResidualDegreeOfFreedomNull" -> instance.rResidualDegreeOfFreedomNull) ~
("rResidualDegreeOfFreedom" -> instance.rResidualDegreeOfFreedom) ~
("rAic" -> instance.rAic) ~
("rNumIterations" -> instance.rNumIterations)
val rMetadataJson: String = compact(render(rMetadata))
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
instance.pipeline.save(pipelinePath)
}
}
class GeneralizedLinearRegressionWrapperReader
extends MLReader[GeneralizedLinearRegressionWrapper] {
override def load(path: String): GeneralizedLinearRegressionWrapper = {
implicit val format = DefaultFormats
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val rFeatures = (rMetadata \ "rFeatures").extract[Array[String]]
val rCoefficients = (rMetadata \ "rCoefficients").extract[Array[Double]]
val rDispersion = (rMetadata \ "rDispersion").extract[Double]
val rNullDeviance = (rMetadata \ "rNullDeviance").extract[Double]
val rDeviance = (rMetadata \ "rDeviance").extract[Double]
val rResidualDegreeOfFreedomNull = (rMetadata \ "rResidualDegreeOfFreedomNull").extract[Long]
val rResidualDegreeOfFreedom = (rMetadata \ "rResidualDegreeOfFreedom").extract[Long]
val rAic = (rMetadata \ "rAic").extract[Double]
val rNumIterations = (rMetadata \ "rNumIterations").extract[Int]
val pipeline = PipelineModel.load(pipelinePath)
new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
rAic, rNumIterations, isLoaded = true)
}
}
}
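
The writer above persists two pieces under the target directory: the fitted ```PipelineModel``` (under ```pipeline```) and a small JSON document with the R-facing statistics (under ```rMetadata```); the reader reassembles the wrapper from both and marks it ```isLoaded = true```. From R the layout can be inspected directly (assuming a local filesystem path, reusing ```m``` and ```modelPath``` from the glm example above):

```r
ml.save(m, modelPath)
list.files(modelPath)
# Expected: "pipeline"  "rMetadata"
```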


@@ -17,28 +17,30 @@
package org.apache.spark.ml.r
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class KMeansWrapper private (
pipeline: PipelineModel) {
val pipeline: PipelineModel,
val features: Array[String],
val size: Array[Long],
val isLoaded: Boolean = false) extends MLWritable {
private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
private lazy val attrs = AttributeGroup.fromStructField(
kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
lazy val k: Int = kMeansModel.getK
lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
lazy val cluster: DataFrame = kMeansModel.summary.cluster
def fitted(method: String): DataFrame = {
@@ -56,9 +58,10 @@ private[r] class KMeansWrapper private (
pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
}
override def write: MLWriter = new KMeansWrapper.KMeansWrapperWriter(this)
}
private[r] object KMeansWrapper {
private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
def fit(
data: DataFrame,
@@ -80,6 +83,48 @@ private[r] object KMeansWrapper {
.setStages(Array(assembler, kMeans))
.fit(data)
new KMeansWrapper(pipeline)
val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
val attrs = AttributeGroup.fromStructField(
kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
val features: Array[String] = attrs.attributes.get.map(_.name.get)
val size: Array[Long] = kMeansModel.summary.clusterSizes
new KMeansWrapper(pipeline, features, size)
}
override def read: MLReader[KMeansWrapper] = new KMeansWrapperReader
override def load(path: String): KMeansWrapper = super.load(path)
class KMeansWrapperWriter(instance: KMeansWrapper) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val rMetadata = ("class" -> instance.getClass.getName) ~
("features" -> instance.features.toSeq) ~
("size" -> instance.size.toSeq)
val rMetadataJson: String = compact(render(rMetadata))
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
instance.pipeline.save(pipelinePath)
}
}
class KMeansWrapperReader extends MLReader[KMeansWrapper] {
override def load(path: String): KMeansWrapper = {
implicit val format = DefaultFormats
val rMetadataPath = new Path(path, "rMetadata").toString
val pipelinePath = new Path(path, "pipeline").toString
val pipeline = PipelineModel.load(pipelinePath)
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val features = (rMetadata \ "features").extract[Array[String]]
val size = (rMetadata \ "size").extract[Array[Long]]
new KMeansWrapper(pipeline, features, size, isLoaded = true)
}
}
}


@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._


@@ -40,6 +40,10 @@ private[r] object RWrappers extends MLReader[Object] {
case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
AFTSurvivalRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
GeneralizedLinearRegressionWrapper.load(path)
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR ml.load does not support load $className")
}