[SPARK-14314][SPARK-14315][ML][SPARKR] Model persistence in SparkR (glm & kmeans)

SparkR `glm` and `kmeans` model persistence, with unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>
Author: Gayathri Murali <gayathri.m.softie@gmail.com>

Closes #12778 from yanboliang/spark-14311.
Closes #12680
Closes #12683
commit 87ac84d437
parent a7d0fedc94
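Before the diff, a minimal usage sketch of the API this patch adds. It is illustrative only: it assumes the Spark 2.0-era SparkR session setup used by the tests below (a `sqlContext` in scope), and the model path is hypothetical.

    # fit, persist, and reload a SparkR glm model
    training <- suppressWarnings(createDataFrame(sqlContext, iris))
    model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)

    modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
    ml.save(model, modelPath)                    # errors if modelPath already exists
    ml.save(model, modelPath, overwrite = TRUE)  # overwrite must be requested explicitly

    model2 <- ml.load(modelPath)                 # dispatched on the persisted wrapper class
    summary(model2)$is.loaded                    # TRUE for a saved-loaded model
    unlink(modelPath)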

R/pkg/R/mllib.R
@@ -99,9 +99,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
           function(object, ...) {
             jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "rFeatures")
             coefficients <- callJMethod(jobj, "rCoefficients")
-            deviance.resid <- callJMethod(jobj, "rDevianceResiduals")
             dispersion <- callJMethod(jobj, "rDispersion")
             null.deviance <- callJMethod(jobj, "rNullDeviance")
             deviance <- callJMethod(jobj, "rDeviance")
@@ -110,15 +110,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
             aic <- callJMethod(jobj, "rAic")
             iter <- callJMethod(jobj, "rNumIterations")
             family <- callJMethod(jobj, "rFamily")
 
-            deviance.resid <- dataFrame(deviance.resid)
+            deviance.resid <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "rDevianceResiduals"))
+            }
             coefficients <- matrix(coefficients, ncol = 4)
             colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
             rownames(coefficients) <- unlist(features)
             ans <- list(deviance.resid = deviance.resid, coefficients = coefficients,
                         dispersion = dispersion, null.deviance = null.deviance,
                         deviance = deviance, df.null = df.null, df.residual = df.residual,
-                        aic = aic, iter = iter, family = family)
+                        aic = aic, iter = iter, family = family, is.loaded = is.loaded)
             class(ans) <- "summary.GeneralizedLinearRegressionModel"
             return(ans)
           })
@@ -129,12 +132,16 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
 #' @name print.summary.GeneralizedLinearRegressionModel
 #' @export
 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
-  x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
-    c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
-  x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
-  cat("\nDeviance Residuals: \n")
-  cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
-  print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+  if (x$is.loaded) {
+    cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n")
+  } else {
+    x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals",
+      c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max"))
+    x$deviance.resid <- zapsmall(x$deviance.resid, 5L)
+    cat("\nDeviance Residuals: \n")
+    cat("(Note: These are approximate quantiles with relative error <= 0.01)\n")
+    print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L)
+  }
 
   cat("\nCoefficients:\n")
   print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L)
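The user-visible effect of the branch above, sketched with a hypothetical loaded model (obtained via ml.load):

    print(summary(loadedModel))
    # ...
    # Saved-loaded model does not support output 'Deviance Residuals'.
    # (a freshly fitted model prints the approximate quantile table instead)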
@@ -246,6 +253,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 #' Get fitted result from a k-means model
 #'
 #' Get fitted result from a k-means model, similarly to R's fitted().
+#' Note: A saved-loaded model does not support this method.
 #'
 #' @param object A fitted k-means model
 #' @return SparkDataFrame containing fitted values
@@ -260,7 +268,13 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
 setMethod("fitted", signature(object = "KMeansModel"),
           function(object, method = c("centers", "classes"), ...) {
             method <- match.arg(method)
-            return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+            jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
+            if (is.loaded) {
+              stop(paste("Saved-loaded k-means model does not support 'fitted' method"))
+            } else {
+              return(dataFrame(callJMethod(jobj, "fitted", method)))
+            }
           })
 
 #' Get the summary of a k-means model
@@ -280,15 +294,21 @@ setMethod("fitted", signature(object = "KMeansModel"),
 setMethod("summary", signature(object = "KMeansModel"),
           function(object, ...) {
             jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "features")
             coefficients <- callJMethod(jobj, "coefficients")
-            cluster <- callJMethod(jobj, "cluster")
             k <- callJMethod(jobj, "k")
             size <- callJMethod(jobj, "size")
             coefficients <- t(matrix(coefficients, ncol = k))
             colnames(coefficients) <- unlist(features)
             rownames(coefficients) <- 1:k
-            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+            cluster <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "cluster"))
+            }
+            return(list(coefficients = coefficients, size = size,
+                        cluster = cluster, is.loaded = is.loaded))
           })
 
 #' Make predictions from a k-means model
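Combined with the fitted() guard earlier, the loaded-model behavior for k-means comes out roughly as follows (model2 is hypothetical, from ml.load):

    s2 <- summary(model2)
    s2$coefficients   # available: cluster centers come back with the pipeline
    s2$cluster        # NULL: training-time cluster assignments are not persisted
    fitted(model2)    # error: Saved-loaded k-means model does not support 'fitted' method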
@@ -389,6 +409,56 @@ setMethod("ml.save", signature(object = "AFTSurvivalRegressionModel", path = "character"),
             invisible(callJMethod(writer, "save", path))
           })
 
+#' Save the generalized linear model to the input path.
+#'
+#' @param object A fitted generalized linear model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
+#' Save the k-means model to the input path.
+#'
+#' @param object A fitted k-means model
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(x, centers = 2, algorithm="random")
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "KMeansModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
 #' Load a fitted MLlib model from the input path.
 #'
 #' @param path Path of the model to read.
@@ -408,6 +478,10 @@ ml.load <- function(path) {
     return(new("NaiveBayesModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
     return(new("AFTSurvivalRegressionModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) {
+    return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
+    return(new("KMeansModel", jobj = jobj))
   } else {
     stop(paste("Unsupported model: ", jobj))
   }

R/pkg/inst/tests/testthat/test_mllib.R
@@ -126,6 +126,33 @@ test_that("glm summary", {
   expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
 })
 
+test_that("glm save/load", {
+  training <- suppressWarnings(createDataFrame(sqlContext, iris))
+  m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
+  s <- summary(m)
+
+  modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+  ml.save(m, modelPath)
+  expect_error(ml.save(m, modelPath))
+  ml.save(m, modelPath, overwrite = TRUE)
+  m2 <- ml.load(modelPath)
+  s2 <- summary(m2)
+
+  expect_equal(s$coefficients, s2$coefficients)
+  expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+  expect_equal(s$dispersion, s2$dispersion)
+  expect_equal(s$null.deviance, s2$null.deviance)
+  expect_equal(s$deviance, s2$deviance)
+  expect_equal(s$df.null, s2$df.null)
+  expect_equal(s$df.residual, s2$df.residual)
+  expect_equal(s$aic, s2$aic)
+  expect_equal(s$iter, s2$iter)
+  expect_true(!s$is.loaded)
+  expect_true(s2$is.loaded)
+
+  unlink(modelPath)
+})
+
 test_that("kmeans", {
   newIris <- iris
   newIris$Species <- NULL
@@ -150,6 +177,20 @@ test_that("kmeans", {
   summary.model <- summary(model)
   cluster <- summary.model$cluster
   expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "kmeans", fileext = ".tmp")
+  ml.save(model, modelPath)
+  expect_error(ml.save(model, modelPath))
+  ml.save(model, modelPath, overwrite = TRUE)
+  model2 <- ml.load(modelPath)
+  summary2 <- summary(model2)
+  expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+  expect_equal(summary.model$coefficients, summary2$coefficients)
+  expect_true(!summary.model$is.loaded)
+  expect_true(summary2$is.loaded)
+
+  unlink(modelPath)
 })
 
 test_that("naiveBayes", {

mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -17,65 +17,34 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression._
+import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 
 private[r] class GeneralizedLinearRegressionWrapper private (
-    pipeline: PipelineModel,
-    val features: Array[String]) {
+    val pipeline: PipelineModel,
+    val rFeatures: Array[String],
+    val rCoefficients: Array[Double],
+    val rDispersion: Double,
+    val rNullDeviance: Double,
+    val rDeviance: Double,
+    val rResidualDegreeOfFreedomNull: Long,
+    val rResidualDegreeOfFreedom: Long,
+    val rAic: Double,
+    val rNumIterations: Int,
+    val isLoaded: Boolean = false) extends MLWritable {
 
   private val glm: GeneralizedLinearRegressionModel =
     pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
 
-  lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
-    Array("(Intercept)") ++ features
-  } else {
-    features
-  }
-
-  lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
-    Array(glm.intercept) ++ glm.coefficients.toArray ++
-      rCoefficientStandardErrors ++ rTValues ++ rPValues
-  } else {
-    glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
-  }
-
-  private lazy val rCoefficientStandardErrors = if (glm.getFitIntercept) {
-    Array(glm.summary.coefficientStandardErrors.last) ++
-      glm.summary.coefficientStandardErrors.dropRight(1)
-  } else {
-    glm.summary.coefficientStandardErrors
-  }
-
-  private lazy val rTValues = if (glm.getFitIntercept) {
-    Array(glm.summary.tValues.last) ++ glm.summary.tValues.dropRight(1)
-  } else {
-    glm.summary.tValues
-  }
-
-  private lazy val rPValues = if (glm.getFitIntercept) {
-    Array(glm.summary.pValues.last) ++ glm.summary.pValues.dropRight(1)
-  } else {
-    glm.summary.pValues
-  }
-
-  lazy val rDispersion: Double = glm.summary.dispersion
-
-  lazy val rNullDeviance: Double = glm.summary.nullDeviance
-
-  lazy val rDeviance: Double = glm.summary.deviance
-
-  lazy val rResidualDegreeOfFreedomNull: Long = glm.summary.residualDegreeOfFreedomNull
-
-  lazy val rResidualDegreeOfFreedom: Long = glm.summary.residualDegreeOfFreedom
-
-  lazy val rAic: Double = glm.summary.aic
-
-  lazy val rNumIterations: Int = glm.summary.numIterations
-
   lazy val rDevianceResiduals: DataFrame = glm.summary.residuals()
 
   lazy val rFamily: String = glm.getFamily
@@ -85,9 +54,13 @@ private[r] class GeneralizedLinearRegressionWrapper private (
   def transform(dataset: Dataset[_]): DataFrame = {
     pipeline.transform(dataset).drop(glm.getFeaturesCol)
   }
+
+  override def write: MLWriter =
+    new GeneralizedLinearRegressionWrapper.GeneralizedLinearRegressionWrapperWriter(this)
 }
 
-private[r] object GeneralizedLinearRegressionWrapper {
+private[r] object GeneralizedLinearRegressionWrapper
+  extends MLReadable[GeneralizedLinearRegressionWrapper] {
 
   def fit(
       formula: String,
@@ -105,15 +78,119 @@ private[r] object GeneralizedLinearRegressionWrapper {
       .attributes.get
     val features = featureAttrs.map(_.name.get)
     // assemble and fit the pipeline
-    val glm = new GeneralizedLinearRegression()
+    val glr = new GeneralizedLinearRegression()
       .setFamily(family)
       .setLink(link)
       .setFitIntercept(rFormula.hasIntercept)
       .setTol(epsilon)
       .setMaxIter(maxit)
     val pipeline = new Pipeline()
-      .setStages(Array(rFormulaModel, glm))
+      .setStages(Array(rFormulaModel, glr))
       .fit(data)
-    new GeneralizedLinearRegressionWrapper(pipeline, features)
+
+    val glm: GeneralizedLinearRegressionModel =
+      pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+    val summary = glm.summary
+
+    val rFeatures: Array[String] = if (glm.getFitIntercept) {
+      Array("(Intercept)") ++ features
+    } else {
+      features
+    }
+
+    val rCoefficientStandardErrors = if (glm.getFitIntercept) {
+      Array(summary.coefficientStandardErrors.last) ++
+        summary.coefficientStandardErrors.dropRight(1)
+    } else {
+      summary.coefficientStandardErrors
+    }
+
+    val rTValues = if (glm.getFitIntercept) {
+      Array(summary.tValues.last) ++ summary.tValues.dropRight(1)
+    } else {
+      summary.tValues
+    }
+
+    val rPValues = if (glm.getFitIntercept) {
+      Array(summary.pValues.last) ++ summary.pValues.dropRight(1)
+    } else {
+      summary.pValues
+    }
+
+    val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
+      Array(glm.intercept) ++ glm.coefficients.toArray ++
+        rCoefficientStandardErrors ++ rTValues ++ rPValues
+    } else {
+      glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
+    }
+
+    val rDispersion: Double = summary.dispersion
+    val rNullDeviance: Double = summary.nullDeviance
+    val rDeviance: Double = summary.deviance
+    val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull
+    val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom
+    val rAic: Double = summary.aic
+    val rNumIterations: Int = summary.numIterations
+
+    new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
+      rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
+      rAic, rNumIterations)
   }
+
+  override def read: MLReader[GeneralizedLinearRegressionWrapper] =
+    new GeneralizedLinearRegressionWrapperReader
+
+  override def load(path: String): GeneralizedLinearRegressionWrapper = super.load(path)
+
+  class GeneralizedLinearRegressionWrapperWriter(instance: GeneralizedLinearRegressionWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("rFeatures" -> instance.rFeatures.toSeq) ~
+        ("rCoefficients" -> instance.rCoefficients.toSeq) ~
+        ("rDispersion" -> instance.rDispersion) ~
+        ("rNullDeviance" -> instance.rNullDeviance) ~
+        ("rDeviance" -> instance.rDeviance) ~
+        ("rResidualDegreeOfFreedomNull" -> instance.rResidualDegreeOfFreedomNull) ~
+        ("rResidualDegreeOfFreedom" -> instance.rResidualDegreeOfFreedom) ~
+        ("rAic" -> instance.rAic) ~
+        ("rNumIterations" -> instance.rNumIterations)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class GeneralizedLinearRegressionWrapperReader
+    extends MLReader[GeneralizedLinearRegressionWrapper] {
+
+    override def load(path: String): GeneralizedLinearRegressionWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val rFeatures = (rMetadata \ "rFeatures").extract[Array[String]]
+      val rCoefficients = (rMetadata \ "rCoefficients").extract[Array[Double]]
+      val rDispersion = (rMetadata \ "rDispersion").extract[Double]
+      val rNullDeviance = (rMetadata \ "rNullDeviance").extract[Double]
+      val rDeviance = (rMetadata \ "rDeviance").extract[Double]
+      val rResidualDegreeOfFreedomNull = (rMetadata \ "rResidualDegreeOfFreedomNull").extract[Long]
+      val rResidualDegreeOfFreedom = (rMetadata \ "rResidualDegreeOfFreedom").extract[Long]
+      val rAic = (rMetadata \ "rAic").extract[Double]
+      val rNumIterations = (rMetadata \ "rNumIterations").extract[Int]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion,
+        rNullDeviance, rDeviance, rResidualDegreeOfFreedomNull, rResidualDegreeOfFreedom,
+        rAic, rNumIterations, isLoaded = true)
+    }
+  }
 }
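For reference, a sketch of the on-disk layout the writer above produces (local path shown for illustration; the single part file reflects saveAsTextFile with one partition, alongside Hadoop's usual bookkeeping files):

    <path>/
      rMetadata/part-00000   # one-line JSON: class, rFeatures, rCoefficients, rDispersion, ...
      pipeline/              # standard PipelineModel persistence (metadata/ and stages/)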

mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -17,28 +17,30 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
 import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class KMeansWrapper private (
-    pipeline: PipelineModel) {
+    val pipeline: PipelineModel,
+    val features: Array[String],
+    val size: Array[Long],
+    val isLoaded: Boolean = false) extends MLWritable {
 
   private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
 
   lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
 
-  private lazy val attrs = AttributeGroup.fromStructField(
-    kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
-
-  lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
-
   lazy val k: Int = kMeansModel.getK
 
-  lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
-
   lazy val cluster: DataFrame = kMeansModel.summary.cluster
 
   def fitted(method: String): DataFrame = {
@@ -56,9 +58,10 @@ private[r] class KMeansWrapper private (
     pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
   }
 
+  override def write: MLWriter = new KMeansWrapper.KMeansWrapperWriter(this)
 }
 
-private[r] object KMeansWrapper {
+private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
 
   def fit(
       data: DataFrame,
@@ -80,6 +83,48 @@ private[r] object KMeansWrapper {
       .setStages(Array(assembler, kMeans))
       .fit(data)
 
-    new KMeansWrapper(pipeline)
+    val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+    val attrs = AttributeGroup.fromStructField(
+      kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+    val features: Array[String] = attrs.attributes.get.map(_.name.get)
+    val size: Array[Long] = kMeansModel.summary.clusterSizes
+
+    new KMeansWrapper(pipeline, features, size)
   }
+
+  override def read: MLReader[KMeansWrapper] = new KMeansWrapperReader
+
+  override def load(path: String): KMeansWrapper = super.load(path)
+
+  class KMeansWrapperWriter(instance: KMeansWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq) ~
+        ("size" -> instance.size.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class KMeansWrapperReader extends MLReader[KMeansWrapper] {
+
+    override def load(path: String): KMeansWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+      val size = (rMetadata \ "size").extract[Array[Long]]
+      new KMeansWrapper(pipeline, features, size, isLoaded = true)
+    }
+  }
 }

mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.r
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 

mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -40,6 +40,10 @@ private[r] object RWrappers extends MLReader[Object] {
       case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
       case "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" =>
         AFTSurvivalRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" =>
+        GeneralizedLinearRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.KMeansWrapper" =>
+        KMeansWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR ml.load does not support load $className")
     }
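The fallback case above surfaces in R roughly as follows (the path is hypothetical):

    ml.load("path/to/unsupported/model")
    # Error: SparkR ml.load does not support load <className>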