[SPARK-16137][SPARKR] randomForest for R
## What changes were proposed in this pull request?

Adds Random Forest Regression and Classification for R, plus clean-up/reordering of generics.R.

## How was this patch tested?

Manual tests and unit tests.

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #15607 from felixcheung/rrandomforest.
This commit is contained in:
parent
2881a2d1d1
commit
b6879b8b35
|
@ -44,7 +44,8 @@ exportMethods("glm",
|
|||
"spark.gaussianMixture",
|
||||
"spark.als",
|
||||
"spark.kstest",
|
||||
"spark.logit")
|
||||
"spark.logit",
|
||||
"spark.randomForest")
|
||||
|
||||
# Job group lifecycle management methods
|
||||
export("setJobGroup",
|
||||
|
@ -350,7 +351,9 @@ export("as.DataFrame",
|
|||
"uncacheTable",
|
||||
"print.summary.GeneralizedLinearRegressionModel",
|
||||
"read.ml",
|
||||
"print.summary.KSTest")
|
||||
"print.summary.KSTest",
|
||||
"print.summary.RandomForestRegressionModel",
|
||||
"print.summary.RandomForestClassificationModel")
|
||||
|
||||
export("structField",
|
||||
"structField.jobj",
|
||||
|
@ -375,6 +378,8 @@ S3method(print, structField)
|
|||
S3method(print, structType)
|
||||
S3method(print, summary.GeneralizedLinearRegressionModel)
|
||||
S3method(print, summary.KSTest)
|
||||
S3method(print, summary.RandomForestRegressionModel)
|
||||
S3method(print, summary.RandomForestClassificationModel)
|
||||
S3method(structField, character)
|
||||
S3method(structField, jobj)
|
||||
S3method(structType, jobj)
|
||||
|
|
|
@ -1310,9 +1310,11 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
|
|||
#' @export
|
||||
setGeneric("year", function(x) { standardGeneric("year") })
|
||||
|
||||
#' @rdname spark.glm
|
||||
###################### Spark.ML Methods ##########################
|
||||
|
||||
#' @rdname fitted
|
||||
#' @export
|
||||
setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
|
||||
setGeneric("fitted")
|
||||
|
||||
#' @param x,y For \code{glm}: logical values indicating whether the response vector
|
||||
#' and model matrix used in the fitting process should be returned as
|
||||
|
@ -1332,13 +1334,38 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
|
|||
#' @export
|
||||
setGeneric("rbind", signature = "...")
|
||||
|
||||
#' @rdname spark.als
|
||||
#' @export
|
||||
setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
|
||||
|
||||
#' @rdname spark.gaussianMixture
|
||||
#' @export
|
||||
setGeneric("spark.gaussianMixture",
|
||||
function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
|
||||
|
||||
#' @rdname spark.glm
|
||||
#' @export
|
||||
setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
|
||||
|
||||
#' @rdname spark.isoreg
|
||||
#' @export
|
||||
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
|
||||
|
||||
#' @rdname spark.kmeans
|
||||
#' @export
|
||||
setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") })
|
||||
|
||||
#' @rdname fitted
|
||||
#' @rdname spark.kstest
|
||||
#' @export
|
||||
setGeneric("fitted")
|
||||
setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") })
|
||||
|
||||
#' @rdname spark.lda
|
||||
#' @export
|
||||
setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
|
||||
|
||||
#' @rdname spark.logit
|
||||
#' @export
|
||||
setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") })
|
||||
|
||||
#' @rdname spark.mlp
|
||||
#' @export
|
||||
|
@ -1348,14 +1375,15 @@ setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") })
|
|||
#' @export
|
||||
setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })
|
||||
|
||||
#' @rdname spark.randomForest
|
||||
#' @export
|
||||
setGeneric("spark.randomForest",
|
||||
function(data, formula, ...) { standardGeneric("spark.randomForest") })
|
||||
|
||||
#' @rdname spark.survreg
|
||||
#' @export
|
||||
setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
|
||||
|
||||
#' @rdname spark.lda
|
||||
#' @export
|
||||
setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
|
||||
|
||||
#' @rdname spark.lda
|
||||
#' @export
|
||||
setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") })
|
||||
|
@ -1364,20 +1392,6 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark
|
|||
#' @export
|
||||
setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
|
||||
|
||||
#' @rdname spark.isoreg
|
||||
#' @export
|
||||
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
|
||||
|
||||
#' @rdname spark.gaussianMixture
|
||||
#' @export
|
||||
setGeneric("spark.gaussianMixture",
|
||||
function(data, formula, ...) {
|
||||
standardGeneric("spark.gaussianMixture")
|
||||
})
|
||||
|
||||
#' @rdname spark.logit
|
||||
#' @export
|
||||
setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") })
|
||||
|
||||
#' @param object a fitted ML model object.
|
||||
#' @param path the directory where the model is saved.
|
||||
|
@ -1385,11 +1399,3 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.
|
|||
#' @rdname write.ml
|
||||
#' @export
|
||||
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
|
||||
|
||||
#' @rdname spark.als
|
||||
#' @export
|
||||
setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
|
||||
|
||||
#' @rdname spark.kstest
|
||||
#' @export
|
||||
setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") })
|
||||
|
|
252
R/pkg/R/mllib.R
252
R/pkg/R/mllib.R
|
@ -102,6 +102,20 @@ setClass("KSTest", representation(jobj = "jobj"))
|
|||
#' @note LogisticRegressionModel since 2.1.0
|
||||
setClass("LogisticRegressionModel", representation(jobj = "jobj"))
|
||||
|
||||
# S4 model classes are thin handles: each holds only a reference (jobj) to the
# backing model object on the JVM side.

#' S4 class that represents a RandomForestRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel
#' @export
#' @note RandomForestRegressionModel since 2.1.0
setClass("RandomForestRegressionModel", representation(jobj = "jobj"))

#' S4 class that represents a RandomForestClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel
#' @export
#' @note RandomForestClassificationModel since 2.1.0
setClass("RandomForestClassificationModel", representation(jobj = "jobj"))
|
||||
|
||||
#' Saves the MLlib model to the input path
|
||||
#'
|
||||
#' Saves the MLlib model to the input path. For more information, see the specific
|
||||
|
@ -112,7 +126,7 @@ setClass("LogisticRegressionModel", representation(jobj = "jobj"))
|
|||
#' @seealso \link{spark.glm}, \link{glm},
|
||||
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
|
||||
#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
|
||||
#' @seealso \link{spark.survreg}
|
||||
#' @seealso \link{spark.randomForest}, \link{spark.survreg},
|
||||
#' @seealso \link{read.ml}
|
||||
NULL
|
||||
|
||||
|
@ -125,7 +139,8 @@ NULL
|
|||
#' @export
|
||||
#' @seealso \link{spark.glm}, \link{glm},
|
||||
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
|
||||
#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
|
||||
#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
|
||||
#' @seealso \link{spark.randomForest}, \link{spark.survreg}
|
||||
NULL
|
||||
|
||||
write_internal <- function(object, path, overwrite = FALSE) {
|
||||
|
@ -1122,6 +1137,10 @@ read.ml <- function(path) {
|
|||
new("ALSModel", jobj = jobj)
|
||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) {
|
||||
new("LogisticRegressionModel", jobj = jobj)
|
||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) {
|
||||
new("RandomForestRegressionModel", jobj = jobj)
|
||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
|
||||
new("RandomForestClassificationModel", jobj = jobj)
|
||||
} else {
|
||||
stop("Unsupported model: ", jobj)
|
||||
}
|
||||
|
@ -1617,3 +1636,232 @@ print.summary.KSTest <- function(x, ...) {
|
|||
cat(summaryStr, "\n")
|
||||
invisible(x)
|
||||
}
|
||||
|
||||
#' Random Forest Model for Regression and Classification
#'
#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest}
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model, one of "regression" or "classification", to fit
#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5)
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#'                how to split on features at each node. More bins give higher granularity. Must be
#'                >= 2 and >= number of categories in any categorical feature. (default = 32)
#' @param numTrees Number of trees to train (>= 1). (default = 20)
#' @param impurity Criterion used for information gain calculation.
#'                 For regression, must be "variance". For classification, must be one of
#'                 "entropy" and "gini". (default = gini)
#' @param minInstancesPerNode Minimum number of instances each child must have after split.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
#' @param featureSubsetStrategy The number of features to consider for splits at each tree node.
#'        Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n].
#' @param seed integer seed for random number generation.
#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
#'                        range (0, 1]. (default = 1.0)
#' @param probabilityCol column name for predicted class conditional probabilities, only for
#'                       classification. (default = "probability")
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
#'                     nodes.
#' @param ... additional arguments passed to the method.
#' @aliases spark.randomForest,SparkDataFrame,formula-method
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
#' @rdname spark.randomForest
#' @name spark.randomForest
#' @export
#' @examples
#' \dontrun{
#' # fit a Random Forest Regression Model
#' df <- createDataFrame(longley)
#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # fit a Random Forest Classification Model
#' df <- createDataFrame(iris)
#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification")
#' }
#' @note spark.randomForest since 2.1.0
setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, type = c("regression", "classification"),
                   maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
                   featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                   probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) {
            type <- match.arg(type)
            # The JVM wrapper expects the formula as a single string.
            formulaStr <- paste(deparse(formula), collapse = "")
            # The seed travels as a string so that NULL can mean "no seed" on the JVM side.
            if (!is.null(seed)) {
              seed <- as.character(as.integer(seed))
            }
            if (type == "regression") {
              # Regression only supports variance impurity.
              if (is.null(impurity)) impurity <- "variance"
              impurity <- match.arg(impurity, "variance")
              jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper",
                                  "fit", data@sdf, formulaStr, as.integer(maxDepth),
                                  as.integer(maxBins), as.integer(numTrees),
                                  impurity, as.integer(minInstancesPerNode),
                                  as.numeric(minInfoGain), as.integer(checkpointInterval),
                                  as.character(featureSubsetStrategy), seed,
                                  as.numeric(subsamplingRate),
                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
              new("RandomForestRegressionModel", jobj = jobj)
            } else {
              # Classification supports gini (default) or entropy impurity.
              if (is.null(impurity)) impurity <- "gini"
              impurity <- match.arg(impurity, c("gini", "entropy"))
              jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
                                  "fit", data@sdf, formulaStr, as.integer(maxDepth),
                                  as.integer(maxBins), as.integer(numTrees),
                                  impurity, as.integer(minInstancesPerNode),
                                  as.numeric(minInfoGain), as.integer(checkpointInterval),
                                  as.character(featureSubsetStrategy), seed,
                                  as.numeric(subsamplingRate), as.character(probabilityCol),
                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
              new("RandomForestClassificationModel", jobj = jobj)
            }
          })
|
||||
|
||||
# Makes predictions from a Random Forest Regression model or Classification model

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#'         "prediction"
#' @rdname spark.randomForest
#' @aliases predict,RandomForestRegressionModel-method
#' @export
#' @note predict(randomForestRegressionModel) since 2.1.0
setMethod("predict", signature(object = "RandomForestRegressionModel"),
          function(object, newData) {
            # Delegate to the shared helper that invokes the JVM model's transform.
            predict_internal(object = object, newData = newData)
          })
|
||||
|
||||
#' @rdname spark.randomForest
#' @aliases predict,RandomForestClassificationModel-method
#' @export
#' @note predict(randomForestClassificationModel) since 2.1.0
setMethod("predict", signature(object = "RandomForestClassificationModel"),
          function(object, newData) {
            # Same code path as the regression model; dispatch picks the right JVM object.
            predict_internal(object = object, newData = newData)
          })
|
||||
|
||||
# Save the Random Forest Regression or Classification model to the input path.

#' @param object A fitted Random Forest regression model or classification model
#' @param path The directory where the model is saved
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#'                  which means throw exception if the output path exists.
#'
#' @aliases write.ml,RandomForestRegressionModel,character-method
#' @rdname spark.randomForest
#' @export
#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0
setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Persistence is handled by the shared MLlib writer helper.
            write_internal(object = object, path = path, overwrite = overwrite)
          })
|
||||
|
||||
#' @aliases write.ml,RandomForestClassificationModel,character-method
#' @rdname spark.randomForest
#' @export
#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0
setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"),
          function(object, path, overwrite = FALSE) {
            # Persistence is handled by the shared MLlib writer helper.
            write_internal(object = object, path = path, overwrite = overwrite)
          })
|
||||
|
||||
# Collects the summary fields shared by both Random Forest model flavors
# (regression and classification) from the JVM-side wrapper.
summary.randomForest <- function(model) {
  jobj <- model@jobj
  # Every field is fetched from the JVM wrapper; the jobj itself is kept so
  # the print method can request the full tree dump lazily.
  list(formula = callJMethod(jobj, "formula"),
       numFeatures = callJMethod(jobj, "numFeatures"),
       features = callJMethod(jobj, "features"),
       featureImportances = callJMethod(callJMethod(jobj, "featureImportances"), "toString"),
       numTrees = callJMethod(jobj, "numTrees"),
       treeWeights = callJMethod(jobj, "treeWeights"),
       jobj = jobj)
}
|
||||
|
||||
#' @return \code{summary} returns the model's features as lists, depth and number of nodes
#'         or number of classes.
#' @rdname spark.randomForest
#' @aliases summary,RandomForestRegressionModel-method
#' @export
#' @note summary(RandomForestRegressionModel) since 2.1.0
setMethod("summary", signature(object = "RandomForestRegressionModel"),
          function(object) {
            # Tag the shared summary list with the regression-specific S3 class
            # so print dispatch finds print.summary.RandomForestRegressionModel.
            structure(summary.randomForest(object),
                      class = "summary.RandomForestRegressionModel")
          })
|
||||
|
||||
# Get the summary of a RandomForestClassificationModel

#' @rdname spark.randomForest
#' @aliases summary,RandomForestClassificationModel-method
#' @export
#' @note summary(RandomForestClassificationModel) since 2.1.0
setMethod("summary", signature(object = "RandomForestClassificationModel"),
          function(object) {
            # Tag the shared summary list with the classification-specific S3 class
            # so print dispatch finds print.summary.RandomForestClassificationModel.
            structure(summary.randomForest(object),
                      class = "summary.RandomForestClassificationModel")
          })
|
||||
|
||||
# Prints the fields common to Random Forest regression and classification summaries,
# followed by the JVM-side tree dump.
print.summary.randomForest <- function(x) {
  cat("Formula: ", x$formula)
  cat("\nNumber of features: ", x$numFeatures)
  cat("\nFeatures: ", unlist(x$features))
  cat("\nFeature importances: ", x$featureImportances)
  cat("\nNumber of trees: ", x$numTrees)
  cat("\nTree weights: ", unlist(x$treeWeights))

  # The full tree description is generated on demand by the JVM wrapper.
  cat("\n", callJMethod(x$jobj, "summary"), "\n")
  invisible(x)
}
|
||||
|
||||
#' @param x summary object of Random Forest regression model or classification model
#'          returned by \code{summary}.
#' @rdname spark.randomForest
#' @export
#' @note print.summary.RandomForestRegressionModel since 2.1.0
print.summary.RandomForestRegressionModel <- function(x, ...) {
  # All printed fields are shared between flavors; reuse the common printer.
  invisible(print.summary.randomForest(x))
}
|
||||
|
||||
# Prints the summary of Random Forest Classification Model

#' @rdname spark.randomForest
#' @export
#' @note print.summary.RandomForestClassificationModel since 2.1.0
print.summary.RandomForestClassificationModel <- function(x, ...) {
  # All printed fields are shared between flavors; reuse the common printer.
  invisible(print.summary.randomForest(x))
}
|
||||
|
|
|
@ -871,4 +871,72 @@ test_that("spark.kstest", {
|
|||
expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
|
||||
})
|
||||
|
||||
test_that("spark.randomForest Regression", {
  data <- suppressWarnings(createDataFrame(longley))
  # Single-tree forest: deterministic predictions regardless of seed.
  model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                              numTrees = 1)

  predictions <- collect(predict(model, data))
  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
                                         63.221, 63.639, 64.989, 63.761,
                                         66.019, 67.857, 68.169, 66.513,
                                         68.655, 69.564, 69.331, 70.551),
               tolerance = 1e-4)

  stats <- summary(model)
  expect_equal(stats$numTrees, 1)
  # Printing the summary must not raise and must produce the full field listing.
  expect_error(capture.output(stats), NA)
  expect_true(length(capture.output(stats)) > 6)

  # Multi-tree forest with a fixed seed: predictions are reproducible.
  model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                              numTrees = 20, seed = 123)
  predictions <- collect(predict(model, data))
  expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
                                         63.736, 64.296, 64.868, 64.300,
                                         66.709, 67.697, 67.966, 67.252,
                                         68.866, 69.593, 69.195, 69.658),
               tolerance = 1e-4)
  stats <- summary(model)
  expect_equal(stats$numTrees, 20)

  # Round-trip save/load: overwriting must be explicit, and the reloaded
  # model's summary must match the original field by field.
  modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  stats2 <- summary(model2)
  expect_equal(stats$formula, stats2$formula)
  expect_equal(stats$numFeatures, stats2$numFeatures)
  expect_equal(stats$features, stats2$features)
  expect_equal(stats$featureImportances, stats2$featureImportances)
  expect_equal(stats$numTrees, stats2$numTrees)
  expect_equal(stats$treeWeights, stats2$treeWeights)

  unlink(modelPath)
})
|
||||
|
||||
test_that("spark.randomForest Classification", {
  data <- suppressWarnings(createDataFrame(iris))
  model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
                              maxDepth = 5, maxBins = 16)

  stats <- summary(model)
  # Two predictor columns in the formula -> two features.
  expect_equal(stats$numFeatures, 2)
  expect_equal(stats$numTrees, 20)
  # Printing the summary must not raise and must produce the full field listing.
  expect_error(capture.output(stats), NA)
  expect_true(length(capture.output(stats)) > 6)

  # Round-trip save/load: overwriting must be explicit, and the reloaded
  # model's summary must agree with the original.
  modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  stats2 <- summary(model2)
  expect_equal(stats$depth, stats2$depth)
  expect_equal(stats$numNodes, stats2$numNodes)
  expect_equal(stats$numClasses, stats2$numClasses)

  unlink(modelPath)
})
|
||||
|
||||
sparkR.session.stop()
|
||||
|
|
|
@ -56,6 +56,10 @@ private[r] object RWrappers extends MLReader[Object] {
|
|||
ALSWrapper.load(path)
|
||||
case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
|
||||
LogisticRegressionWrapper.load(path)
|
||||
case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
|
||||
RandomForestRegressorWrapper.load(path)
|
||||
case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
|
||||
RandomForestClassifierWrapper.load(path)
|
||||
case _ =>
|
||||
throw new SparkException(s"SparkR read.ml does not support load $className")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.r
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.json4s._
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.ml.attribute.AttributeGroup
|
||||
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
|
||||
import org.apache.spark.ml.feature.RFormula
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
|
||||
/**
 * R-facing wrapper around a fitted RFormula + RandomForestClassifier pipeline.
 * Exposes the fields SparkR's summary()/predict() need via the JVM backend.
 */
private[r] class RandomForestClassifierWrapper private (
    val pipeline: PipelineModel,
    val formula: String,
    val features: Array[String]) extends MLWritable {

  // The fitted classifier is the second pipeline stage (the RFormula model is first).
  private val rfcModel: RandomForestClassificationModel =
    pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]

  lazy val numFeatures: Int = rfcModel.numFeatures
  lazy val featureImportances: Vector = rfcModel.featureImportances
  lazy val numTrees: Int = rfcModel.getNumTrees
  lazy val treeWeights: Array[Double] = rfcModel.treeWeights

  // Human-readable dump of all trees, surfaced to R's print method.
  def summary: String = rfcModel.toDebugString

  def transform(dataset: Dataset[_]): DataFrame = {
    // Drop the assembled features vector column; R callers only need predictions.
    pipeline.transform(dataset).drop(rfcModel.getFeaturesCol)
  }

  override def write: MLWriter =
    new RandomForestClassifierWrapper.RandomForestClassifierWrapperWriter(this)
}
|
||||
|
||||
private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {

  /** Trains an RFormula + RandomForestClassifier pipeline from R-supplied parameters. */
  def fit(  // scalastyle:ignore
      data: DataFrame,
      formula: String,
      maxDepth: Int,
      maxBins: Int,
      numTrees: Int,
      impurity: String,
      minInstancesPerNode: Int,
      minInfoGain: Double,
      checkpointInterval: Int,
      featureSubsetStrategy: String,
      seed: String,
      subsamplingRate: Double,
      probabilityCol: String,
      maxMemoryInMB: Int,
      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {

    val formulaTransformer = new RFormula().setFormula(formula)
    RWrapperUtils.checkDataColumns(formulaTransformer, data)
    val formulaModel = formulaTransformer.fit(data)

    // Recover the expanded feature names from the transformed schema's attribute metadata.
    val transformedSchema = formulaModel.transform(data).schema
    val featureNames = AttributeGroup
      .fromStructField(transformedSchema(formulaModel.getFeaturesCol))
      .attributes.get
      .map(_.name.get)

    // Assemble and fit the pipeline.
    val classifier = new RandomForestClassifier()
      .setMaxDepth(maxDepth)
      .setMaxBins(maxBins)
      .setNumTrees(numTrees)
      .setImpurity(impurity)
      .setMinInstancesPerNode(minInstancesPerNode)
      .setMinInfoGain(minInfoGain)
      .setCheckpointInterval(checkpointInterval)
      .setFeatureSubsetStrategy(featureSubsetStrategy)
      .setSubsamplingRate(subsamplingRate)
      .setMaxMemoryInMB(maxMemoryInMB)
      .setCacheNodeIds(cacheNodeIds)
      .setProbabilityCol(probabilityCol)
      .setFeaturesCol(formulaTransformer.getFeaturesCol)
    // An empty seed string from R means "use the default seed".
    if (seed != null && seed.length > 0) classifier.setSeed(seed.toLong)

    val fittedPipeline = new Pipeline()
      .setStages(Array(formulaModel, classifier))
      .fit(data)

    new RandomForestClassifierWrapper(fittedPipeline, formula, featureNames)
  }

  override def read: MLReader[RandomForestClassifierWrapper] =
    new RandomForestClassifierWrapperReader

  override def load(path: String): RandomForestClassifierWrapper = super.load(path)

  class RandomForestClassifierWrapperWriter(instance: RandomForestClassifierWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      // JSON metadata for the R side: wrapper class, formula, and expanded feature names.
      val rMetadataJson: String = compact(render(
        ("class" -> instance.getClass.getName) ~
        ("formula" -> instance.formula) ~
        ("features" -> instance.features.toSeq)))

      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
      instance.pipeline.save(pipelinePath)
    }
  }

  class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] {

    override def load(path: String): RandomForestClassifierWrapper = {
      implicit val format = DefaultFormats
      val pipeline = PipelineModel.load(new Path(path, "pipeline").toString)

      // Rebuild the wrapper from the JSON metadata written by saveImpl.
      val rMetadata = parse(sc.textFile(new Path(path, "rMetadata").toString, 1).first())
      val formula = (rMetadata \ "formula").extract[String]
      val features = (rMetadata \ "features").extract[Array[String]]

      new RandomForestClassifierWrapper(pipeline, formula, features)
    }
  }
}
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.r
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.json4s._
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.ml.attribute.AttributeGroup
|
||||
import org.apache.spark.ml.feature.RFormula
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
|
||||
/**
 * R-facing wrapper around a fitted RFormula + RandomForestRegressor pipeline.
 * Exposes the fields SparkR's summary()/predict() need via the JVM backend.
 */
private[r] class RandomForestRegressorWrapper private (
    val pipeline: PipelineModel,
    val formula: String,
    val features: Array[String]) extends MLWritable {

  // The fitted regressor is the second pipeline stage (the RFormula model is first).
  private val rfrModel: RandomForestRegressionModel =
    pipeline.stages(1).asInstanceOf[RandomForestRegressionModel]

  lazy val numFeatures: Int = rfrModel.numFeatures
  lazy val featureImportances: Vector = rfrModel.featureImportances
  lazy val numTrees: Int = rfrModel.getNumTrees
  lazy val treeWeights: Array[Double] = rfrModel.treeWeights

  // Human-readable dump of all trees, surfaced to R's print method.
  def summary: String = rfrModel.toDebugString

  def transform(dataset: Dataset[_]): DataFrame = {
    // Drop the assembled features vector column; R callers only need predictions.
    pipeline.transform(dataset).drop(rfrModel.getFeaturesCol)
  }

  override def write: MLWriter =
    new RandomForestRegressorWrapper.RandomForestRegressorWrapperWriter(this)
}
|
||||
|
||||
private[r] object RandomForestRegressorWrapper extends MLReadable[RandomForestRegressorWrapper] {

  /** Trains an RFormula + RandomForestRegressor pipeline from R-supplied parameters. */
  def fit(  // scalastyle:ignore
      data: DataFrame,
      formula: String,
      maxDepth: Int,
      maxBins: Int,
      numTrees: Int,
      impurity: String,
      minInstancesPerNode: Int,
      minInfoGain: Double,
      checkpointInterval: Int,
      featureSubsetStrategy: String,
      seed: String,
      subsamplingRate: Double,
      maxMemoryInMB: Int,
      cacheNodeIds: Boolean): RandomForestRegressorWrapper = {

    val formulaTransformer = new RFormula().setFormula(formula)
    RWrapperUtils.checkDataColumns(formulaTransformer, data)
    val formulaModel = formulaTransformer.fit(data)

    // Recover the expanded feature names from the transformed schema's attribute metadata.
    val transformedSchema = formulaModel.transform(data).schema
    val featureNames = AttributeGroup
      .fromStructField(transformedSchema(formulaModel.getFeaturesCol))
      .attributes.get
      .map(_.name.get)

    // Assemble and fit the pipeline.
    val regressor = new RandomForestRegressor()
      .setMaxDepth(maxDepth)
      .setMaxBins(maxBins)
      .setNumTrees(numTrees)
      .setImpurity(impurity)
      .setMinInstancesPerNode(minInstancesPerNode)
      .setMinInfoGain(minInfoGain)
      .setCheckpointInterval(checkpointInterval)
      .setFeatureSubsetStrategy(featureSubsetStrategy)
      .setSubsamplingRate(subsamplingRate)
      .setMaxMemoryInMB(maxMemoryInMB)
      .setCacheNodeIds(cacheNodeIds)
      .setFeaturesCol(formulaTransformer.getFeaturesCol)
    // An empty seed string from R means "use the default seed".
    if (seed != null && seed.length > 0) regressor.setSeed(seed.toLong)

    val fittedPipeline = new Pipeline()
      .setStages(Array(formulaModel, regressor))
      .fit(data)

    new RandomForestRegressorWrapper(fittedPipeline, formula, featureNames)
  }

  override def read: MLReader[RandomForestRegressorWrapper] =
    new RandomForestRegressorWrapperReader

  override def load(path: String): RandomForestRegressorWrapper = super.load(path)

  class RandomForestRegressorWrapperWriter(instance: RandomForestRegressorWrapper)
    extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val rMetadataPath = new Path(path, "rMetadata").toString
      val pipelinePath = new Path(path, "pipeline").toString

      // JSON metadata for the R side: wrapper class, formula, and expanded feature names.
      val rMetadataJson: String = compact(render(
        ("class" -> instance.getClass.getName) ~
        ("formula" -> instance.formula) ~
        ("features" -> instance.features.toSeq)))

      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
      instance.pipeline.save(pipelinePath)
    }
  }

  class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] {

    override def load(path: String): RandomForestRegressorWrapper = {
      implicit val format = DefaultFormats
      val pipeline = PipelineModel.load(new Path(path, "pipeline").toString)

      // Rebuild the wrapper from the JSON metadata written by saveImpl.
      val rMetadata = parse(sc.textFile(new Path(path, "rMetadata").toString, 1).first())
      val formula = (rMetadata \ "formula").extract[String]
      val features = (rMetadata \ "features").extract[Array[String]]

      new RandomForestRegressorWrapper(pipeline, formula, features)
    }
  }
}
|
Loading…
Reference in a new issue