[SPARK-17157][SPARKR] Add multiclass logistic regression SparkR Wrapper
## What changes were proposed in this pull request? As we discussed in #14818, I added a separate R wrapper spark.logit for logistic regression. This single interface supports both binary and multinomial logistic regression. It also has "predict" and "summary" for binary logistic regression. ## How was this patch tested? New unit tests are added. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #15365 from wangmiao1981/glm.
This commit is contained in:
parent
5b7d403c18
commit
29cea8f332
|
@ -43,7 +43,8 @@ exportMethods("glm",
|
|||
"spark.isoreg",
|
||||
"spark.gaussianMixture",
|
||||
"spark.als",
|
||||
"spark.kstest")
|
||||
"spark.kstest",
|
||||
"spark.logit")
|
||||
|
||||
# Job group lifecycle management methods
|
||||
export("setJobGroup",
|
||||
|
|
|
@ -1375,6 +1375,10 @@ setGeneric("spark.gaussianMixture",
|
|||
standardGeneric("spark.gaussianMixture")
|
||||
})
|
||||
|
||||
#' @rdname spark.logit
|
||||
#' @export
|
||||
setGeneric("spark.logit",
           function(data, formula, ...) {
             standardGeneric("spark.logit")
           })
|
||||
|
||||
#' @param object a fitted ML model object.
|
||||
#' @param path the directory where the model is saved.
|
||||
#' @param ... additional argument(s) passed to the method.
|
||||
|
|
192
R/pkg/R/mllib.R
192
R/pkg/R/mllib.R
|
@ -95,6 +95,13 @@ setClass("ALSModel", representation(jobj = "jobj"))
|
|||
#' @note KSTest since 2.1.0
|
||||
setClass("KSTest", representation(jobj = "jobj"))
|
||||
|
||||
#' S4 class that represents a LogisticRegressionModel
|
||||
#'
|
||||
#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel
|
||||
#' @export
|
||||
#' @note LogisticRegressionModel since 2.1.0
|
||||
setClass("LogisticRegressionModel",
         representation = representation(jobj = "jobj"))
|
||||
|
||||
#' Saves the MLlib model to the input path
|
||||
#'
|
||||
#' Saves the MLlib model to the input path. For more information, see the specific
|
||||
|
@ -105,7 +112,7 @@ setClass("KSTest", representation(jobj = "jobj"))
|
|||
#' @seealso \link{spark.glm}, \link{glm},
|
||||
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
|
||||
#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
|
||||
#' @seealso \link{read.ml}
|
||||
#' @seealso \link{spark.logit}, \link{read.ml}
|
||||
NULL
|
||||
|
||||
#' Makes predictions from a MLlib model
|
||||
|
@ -117,7 +124,7 @@ NULL
|
|||
#' @export
|
||||
#' @seealso \link{spark.glm}, \link{glm},
|
||||
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
|
||||
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
|
||||
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.logit}
|
||||
NULL
|
||||
|
||||
write_internal <- function(object, path, overwrite = FALSE) {
|
||||
|
@ -647,6 +654,170 @@ setMethod("predict", signature(object = "KMeansModel"),
|
|||
predict_internal(object, newData)
|
||||
})
|
||||
|
||||
#' Logistic Regression Model
|
||||
#'
|
||||
#' Fits a logistic regression model against a Spark DataFrame. It supports "binomial": Binary logistic regression
|
||||
#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet.
|
||||
#' Users can print, make predictions on the produced model and save the model to the input path.
|
||||
#'
|
||||
#' @param data SparkDataFrame for training
|
||||
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
|
||||
#' operators are supported, including '~', '.', ':', '+', and '-'.
|
||||
#' @param regParam the regularization parameter. Default is 0.0.
|
||||
#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty.
|
||||
#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination
|
||||
#' of L1 and L2. Default is 0.0 which is an L2 penalty.
|
||||
#' @param maxIter maximum iteration number.
|
||||
#' @param tol convergence tolerance of iterations.
|
||||
#' @param fitIntercept whether to fit an intercept term. Default is TRUE.
|
||||
#' @param family the name of family which is a description of the label distribution to be used in the model.
|
||||
#' Supported options:
|
||||
#' \itemize{
|
||||
#' \item{"auto": Automatically select the family based on the number of classes:
|
||||
#' If number of classes == 1 || number of classes == 2, set to "binomial".
|
||||
#' Else, set to "multinomial".}
|
||||
#' \item{"binomial": Binary logistic regression with pivoting.}
|
||||
#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.}
|
||||
#' }
|
||||
#' Default is "auto".
|
||||
#' @param standardization whether to standardize the training features before fitting the model. The coefficients
|
||||
#' of models will be always returned on the original scale, so it will be transparent for
|
||||
#' users. Note that with/without standardization, the models should be always converged
|
||||
#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet.
|
||||
#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1
|
||||
#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0
|
||||
#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with
|
||||
#' threshold p is equivalent to setting thresholds c(1-p, p). When threshold is set, any user-set
|
||||
#' value for thresholds will be cleared. If both threshold and thresholds are set, then they must be
|
||||
#' equivalent. In multiclass (or binary) classification to adjust the probability of
|
||||
#' predicting each class. Array must have length equal to the number of classes, with values > 0,
|
||||
#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
|
||||
#' is the original probability of that class and t is the class's threshold. Note: When thresholds
|
||||
#' is set, any user-set value for threshold will be cleared. If both threshold and thresholds are
|
||||
#' set, then they must be equivalent. Default is 0.5.
|
||||
#' @param weightCol The weight column name.
|
||||
#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions
|
||||
#' are large, this param could be adjusted to a larger size. Default is 2.
|
||||
#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability".
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @return \code{spark.logit} returns a fitted logistic regression model
|
||||
#' @rdname spark.logit
|
||||
#' @aliases spark.logit,SparkDataFrame,formula-method
|
||||
#' @name spark.logit
|
||||
#' @export
|
||||
#' @examples
|
||||
#' \dontrun{
|
||||
#' sparkR.session()
|
||||
#' # binary logistic regression
|
||||
#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0)
|
||||
#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
|
||||
#' binary_data <- as.data.frame(cbind(label, feature))
|
||||
#' binary_df <- createDataFrame(binary_data)
|
||||
#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0)
|
||||
#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction"))
|
||||
#'
|
||||
#' # summary of binary logistic regression
|
||||
#' blr_summary <- summary(blr_model)
|
||||
#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure"))
|
||||
#' # save fitted model to input path
|
||||
#' path <- "path/to/model"
|
||||
#' write.ml(blr_model, path)
|
||||
#'
|
||||
#' # can also read back the saved model and predict
|
||||
#' # Note that summary does not work on loaded model
|
||||
#' savedModel <- read.ml(path)
|
||||
#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction"))
|
||||
#'
|
||||
#' # multinomial logistic regression
|
||||
#'
|
||||
#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0)
|
||||
#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667)
|
||||
#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987)
|
||||
#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130)
|
||||
#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842)
|
||||
#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4))
|
||||
#' df <- createDataFrame(data)
|
||||
#'
|
||||
#' # Note that summary of multinomial logistic regression is not implemented yet
|
||||
#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds=c(0, 1, 1))
|
||||
#' predict1 <- collect(select(predict(model, df), "prediction"))
|
||||
#' }
|
||||
#' @note spark.logit since 2.1.0
|
||||
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
                   tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE,
                   thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
                   probabilityCol = "probability") {
            # The JVM wrapper takes the formula as one string; deparse() may split a
            # long formula into several pieces, so they are glued back together.
            formulaText <- paste(deparse(formula), collapse = "")

            # An empty string is sent in place of NULL when no weight column is given.
            weightColName <- if (is.null(weightCol)) "" else weightCol

            # Arguments are coerced to the types of the wrapper's fit() parameters
            # and passed positionally, in the order fit() declares them.
            fitted <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
                                  data@sdf,
                                  formulaText,
                                  as.numeric(regParam),
                                  as.numeric(elasticNetParam),
                                  as.integer(maxIter),
                                  as.numeric(tol),
                                  as.logical(fitIntercept),
                                  as.character(family),
                                  as.logical(standardization),
                                  as.array(thresholds),
                                  as.character(weightColName),
                                  as.integer(aggregationDepth),
                                  as.character(probabilityCol))

            new("LogisticRegressionModel", jobj = fitted)
          })
|
||||
|
||||
# Predicted values based on a LogisticRegressionModel
|
||||
|
||||
#' @param newData a SparkDataFrame for testing.
|
||||
#' @return \code{predict} returns the predicted values based on a LogisticRegressionModel.
|
||||
#' @rdname spark.logit
|
||||
#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method
|
||||
#' @export
|
||||
#' @note predict(LogisticRegressionModel) since 2.1.0
|
||||
setMethod("predict", signature(object = "LogisticRegressionModel"),
          function(object, newData) {
            # Prediction is handled by the shared predict_internal helper.
            return(predict_internal(object, newData))
          })
|
||||
|
||||
# Get the summary of a LogisticRegressionModel
|
||||
|
||||
#' @param object a LogisticRegressionModel fitted by \code{spark.logit}
|
||||
#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that
|
||||
#' Multinomial logistic regression summary is not available now.
|
||||
#' @rdname spark.logit
|
||||
#' @aliases summary,LogisticRegressionModel-method
|
||||
#' @export
|
||||
#' @note summary(LogisticRegressionModel) since 2.1.0
|
||||
setMethod("summary", signature(object = "LogisticRegressionModel"),
          function(object) {
            wrapper <- object@jobj

            # The wrapper flags models that were read back from disk; no training
            # summary is available for those.
            if (callJMethod(wrapper, "isLoaded")) {
              stop("Loaded model doesn't have training summary.")
            }

            # Fetch a wrapper member either as a plain value or as a SparkDataFrame.
            fetchValue <- function(member) {
              callJMethod(wrapper, member)
            }
            fetchFrame <- function(member) {
              dataFrame(fetchValue(member))
            }

            # list() evaluates its arguments left to right, so the members are
            # requested from the JVM in the order written here.
            list(roc = fetchFrame("roc"),
                 areaUnderROC = fetchValue("areaUnderROC"),
                 pr = fetchFrame("pr"),
                 fMeasureByThreshold = fetchFrame("fMeasureByThreshold"),
                 precisionByThreshold = fetchFrame("precisionByThreshold"),
                 recallByThreshold = fetchFrame("recallByThreshold"),
                 totalIterations = fetchValue("totalIterations"),
                 objectiveHistory = fetchValue("objectiveHistory"))
          })
|
||||
|
||||
#' Multilayer Perceptron Classification Model
|
||||
#'
|
||||
#' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame.
|
||||
|
@ -888,6 +1059,21 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
|
|||
write_internal(object, path, overwrite)
|
||||
})
|
||||
|
||||
# Save fitted LogisticRegressionModel to the input path
|
||||
|
||||
#' @param path The directory where the model is saved
|
||||
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
|
||||
#' which means throw exception if the output path exists.
|
||||
#'
|
||||
#' @rdname spark.logit
|
||||
#' @aliases write.ml,LogisticRegressionModel,character-method
|
||||
#' @export
|
||||
#' @note write.ml(LogisticRegressionModel, character) since 2.1.0
|
||||
setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Saving is handled by the shared write_internal helper.
            return(write_internal(object, path, overwrite))
          })
|
||||
|
||||
# Save fitted MLlib model to the input path
|
||||
|
||||
#' @param path the directory where the model is saved.
|
||||
|
@ -938,6 +1124,8 @@ read.ml <- function(path) {
|
|||
new("GaussianMixtureModel", jobj = jobj)
|
||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
|
||||
new("ALSModel", jobj = jobj)
|
||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) {
|
||||
new("LogisticRegressionModel", jobj = jobj)
|
||||
} else {
|
||||
stop("Unsupported model: ", jobj)
|
||||
}
|
||||
|
|
|
@ -602,6 +602,61 @@ test_that("spark.isotonicRegression", {
|
|||
unlink(modelPath)
|
||||
})
|
||||
|
||||
test_that("spark.logit", {
  # test binary logistic regression
  label <- c(1.0, 1.0, 1.0, 0.0, 0.0)
  feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776)
  binary_data <- as.data.frame(cbind(label, feature))
  binary_df <- createDataFrame(binary_data)

  # With a threshold of 1.0 every row is expected to be predicted as class 0 ...
  blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0)
  blr_predict <- collect(select(predict(blr_model, binary_df), "prediction"))
  expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0))
  # ... and with a threshold of 0.0 every row is expected to be predicted as class 1.
  blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0)
  blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction"))
  expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1))

  # test summary of binary logistic regression
  blr_summary <- summary(blr_model)
  blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure"))
  expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487),
               tolerance = 1e-4)
  expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000),
               tolerance = 1e-4)
  blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision"))
  expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000),
               tolerance = 1e-4)
  blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall"))
  expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000),
               tolerance = 1e-4)

  # test model save and read
  modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp")
  write.ml(blr_model, modelPath)
  # Saving to an existing path must fail unless overwrite is requested.
  expect_error(write.ml(blr_model, modelPath))
  write.ml(blr_model, modelPath, overwrite = TRUE)
  blr_model2 <- read.ml(modelPath)
  blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction"))
  expect_equal(blr_predict$prediction, blr_predict2$prediction)
  # A model read back from disk has no training summary.
  expect_error(summary(blr_model2))
  # The saved model is a directory, and unlink() only removes directories when
  # recursive = TRUE; without it the temporary model would be left behind.
  unlink(modelPath, recursive = TRUE)

  # test multinomial logistic regression
  label <- c(0.0, 1.0, 2.0, 0.0, 0.0)
  feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667)
  feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987)
  feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130)
  feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842)
  data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4))
  df <- createDataFrame(data)

  model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1))
  predict1 <- collect(select(predict(model, df), "prediction"))
  expect_equal(predict1$prediction, c(0, 0, 0, 0, 0))
  # Summary of multinomial logistic regression is not implemented yet
  expect_error(summary(model))
})
|
||||
|
||||
test_that("spark.gaussianMixture", {
|
||||
# R code to reproduce the result.
|
||||
# nolint start
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.r
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.json4s._
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.ml.attribute.AttributeGroup
|
||||
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
|
||||
import org.apache.spark.ml.feature.RFormula
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
|
||||
/**
 * Wrapper around a fitted pipeline whose second stage is a [[LogisticRegressionModel]],
 * exposing the model's training summary and a `transform` method.
 *
 * @param pipeline fitted pipeline; the stage at index 1 is read as the logistic regression model
 * @param features feature names stored alongside the model
 * @param isLoaded true when the instance was created by the reader instead of by `fit`
 */
private[r] class LogisticRegressionWrapper private (
    val pipeline: PipelineModel,
    val features: Array[String],
    val isLoaded: Boolean = false) extends MLWritable {

  // Both `fit` and the reader in the companion build the pipeline so that the
  // logistic regression model is the stage at index 1.
  private val logisticRegressionModel: LogisticRegressionModel =
    pipeline.stages(1).asInstanceOf[LogisticRegressionModel]

  // The summary members are lazy so that constructing a wrapper never touches
  // `summary`; it is only accessed when one of them is first read.
  lazy val totalIterations: Int = logisticRegressionModel.summary.totalIterations

  lazy val objectiveHistory: Array[Double] = logisticRegressionModel.summary.objectiveHistory

  /**
   * The training summary viewed as a binary summary.
   *
   * @throws UnsupportedOperationException if the summary is not a
   *                                       [[BinaryLogisticRegressionSummary]]
   */
  lazy val blrSummary: BinaryLogisticRegressionSummary =
    logisticRegressionModel.summary match {
      case binary: BinaryLogisticRegressionSummary => binary
      case other =>
        throw new UnsupportedOperationException(
          "Binary logistic regression summary is not available; " +
            s"found ${other.getClass.getName}.")
    }

  lazy val roc: DataFrame = blrSummary.roc

  lazy val areaUnderROC: Double = blrSummary.areaUnderROC

  lazy val pr: DataFrame = blrSummary.pr

  lazy val fMeasureByThreshold: DataFrame = blrSummary.fMeasureByThreshold

  lazy val precisionByThreshold: DataFrame = blrSummary.precisionByThreshold

  lazy val recallByThreshold: DataFrame = blrSummary.recallByThreshold

  /** Runs the whole pipeline on `dataset` and drops the model's features column. */
  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset).drop(logisticRegressionModel.getFeaturesCol)
  }

  override def write: MLWriter = new LogisticRegressionWrapper.LogisticRegressionWrapperWriter(this)
}
|
||||
|
||||
private[r] object LogisticRegressionWrapper
  extends MLReadable[LogisticRegressionWrapper] {

  /**
   * Fits an RFormula followed by a LogisticRegression on `data` and wraps the
   * resulting pipeline model.
   *
   * @param data training data
   * @param formula R model formula, as a string
   * @param regParam value passed to `setRegParam`
   * @param elasticNetParam value passed to `setElasticNetParam`
   * @param maxIter value passed to `setMaxIter`
   * @param tol value passed to `setTol`
   * @param fitIntercept value passed to `setFitIntercept`
   * @param family value passed to `setFamily`
   * @param standardization value passed to `setStandardization`
   * @param thresholds a single value is applied with `setThreshold`, several values
   *                   with `setThresholds`; must not be empty
   * @param weightCol value passed to `setWeightCol`
   * @param aggregationDepth value passed to `setAggregationDepth`
   * @param probability value passed to `setProbabilityCol`
   * @return a wrapper around the fitted pipeline and the feature names
   * @throws IllegalArgumentException if `thresholds` is empty
   */
  def fit( // scalastyle:ignore
      data: DataFrame,
      formula: String,
      regParam: Double,
      elasticNetParam: Double,
      maxIter: Int,
      tol: Double,
      fitIntercept: Boolean,
      family: String,
      standardization: Boolean,
      thresholds: Array[Double],
      weightCol: String,
      aggregationDepth: Int,
      probability: String
    ): LogisticRegressionWrapper = {

    // An empty array can select neither setThreshold nor setThresholds; reject it
    // up front, before any fitting work is done.
    require(thresholds.nonEmpty, "thresholds must contain at least one value.")

    val rFormula = new RFormula()
      .setFormula(formula)
    RWrapperUtils.checkDataColumns(rFormula, data)
    val rFormulaModel = rFormula.fit(data)

    // get feature names from output schema
    val schema = rFormulaModel.transform(data).schema
    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
      .attributes.get
    val features = featureAttrs.map(_.name.get)

    // assemble and fit the pipeline
    val logisticRegression = new LogisticRegression()
      .setRegParam(regParam)
      .setElasticNetParam(elasticNetParam)
      .setMaxIter(maxIter)
      .setTol(tol)
      .setFitIntercept(fitIntercept)
      .setFamily(family)
      .setStandardization(standardization)
      .setWeightCol(weightCol)
      .setAggregationDepth(aggregationDepth)
      .setFeaturesCol(rFormula.getFeaturesCol)
      .setProbabilityCol(probability)

    // A single value is applied as `threshold`; several values as `thresholds`.
    if (thresholds.length > 1) {
      logisticRegression.setThresholds(thresholds)
    } else {
      logisticRegression.setThreshold(thresholds.head)
    }

    val pipeline = new Pipeline()
      .setStages(Array(rFormulaModel, logisticRegression))
      .fit(data)

    new LogisticRegressionWrapper(pipeline, features)
  }

  override def read: MLReader[LogisticRegressionWrapper] = new LogisticRegressionWrapperReader

  override def load(path: String): LogisticRegressionWrapper = super.load(path)

  /**
   * Saves a wrapper as two entries under the target path: `rMetadata`, a one-line
   * JSON text file holding the wrapper class name and the feature names, and
   * `pipeline`, the fitted pipeline model.
   */
  class LogisticRegressionWrapperWriter(instance: LogisticRegressionWrapper) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      val rMetadata = ("class" -> instance.getClass.getName) ~
        ("features" -> instance.features.toSeq)
      val rMetadataJson: String = compact(render(rMetadata))
      // A single partition keeps the metadata in one output file.
      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)

      instance.pipeline.save(pipelinePath)
    }
  }

  /**
   * Restores a wrapper from the `rMetadata` and `pipeline` entries written by
   * [[LogisticRegressionWrapperWriter]]; the result is flagged with `isLoaded = true`.
   */
  class LogisticRegressionWrapperReader extends MLReader[LogisticRegressionWrapper] {

    override def load(path: String): LogisticRegressionWrapper = {
      implicit val format = DefaultFormats
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      // The metadata is the first (and only written) line of the text file.
      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
      val rMetadata = parse(rMetadataStr)
      val features = (rMetadata \ "features").extract[Array[String]]

      val pipeline = PipelineModel.load(pipelinePath)
      new LogisticRegressionWrapper(pipeline, features, isLoaded = true)
    }
  }
}
|
|
@ -54,6 +54,8 @@ private[r] object RWrappers extends MLReader[Object] {
|
|||
GaussianMixtureWrapper.load(path)
|
||||
case "org.apache.spark.ml.r.ALSWrapper" =>
|
||||
ALSWrapper.load(path)
|
||||
case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
|
||||
LogisticRegressionWrapper.load(path)
|
||||
case _ =>
|
||||
throw new SparkException(s"SparkR read.ml does not support load $className")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue