[SPARK-30819][SPARKR][ML] Add FMRegressor wrapper to SparkR
### What changes were proposed in this pull request? This pull request adds a SparkR wrapper for `FMRegressor`: - Supporting `org.apache.spark.ml.r.FMRegressorWrapper`. - `FMRegressionModel` S4 class. - Corresponding `spark.fmRegressor`, `predict`, `summary` and `write.ml` generics. - Corresponding docs and tests. ### Why are the changes needed? Feature parity. ### Does this PR introduce any user-facing change? No (new API). ### How was this patch tested? New unit tests. Closes #27571 from zero323/SPARK-30819. Authored-by: zero323 <mszymkiewicz@gmail.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
61f903fa7a
commit
697fe911ac
|
@ -74,7 +74,8 @@ exportMethods("glm",
|
||||||
"spark.findFrequentSequentialPatterns",
|
"spark.findFrequentSequentialPatterns",
|
||||||
"spark.assignClusters",
|
"spark.assignClusters",
|
||||||
"spark.fmClassifier",
|
"spark.fmClassifier",
|
||||||
"spark.lm")
|
"spark.lm",
|
||||||
|
"spark.fmRegressor")
|
||||||
|
|
||||||
# Job group lifecycle management methods
|
# Job group lifecycle management methods
|
||||||
export("setJobGroup",
|
export("setJobGroup",
|
||||||
|
|
|
@ -1483,6 +1483,10 @@ setGeneric("spark.bisectingKmeans",
|
||||||
setGeneric("spark.fmClassifier",
|
setGeneric("spark.fmClassifier",
|
||||||
function(data, formula, ...) { standardGeneric("spark.fmClassifier") })
|
function(data, formula, ...) { standardGeneric("spark.fmClassifier") })
|
||||||
|
|
||||||
|
#' @rdname spark.fmRegressor
|
||||||
|
setGeneric("spark.fmRegressor",
|
||||||
|
function(data, formula, ...) { standardGeneric("spark.fmRegressor") })
|
||||||
|
|
||||||
#' @rdname spark.gaussianMixture
|
#' @rdname spark.gaussianMixture
|
||||||
setGeneric("spark.gaussianMixture",
|
setGeneric("spark.gaussianMixture",
|
||||||
function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
|
function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
|
||||||
|
|
|
@ -42,6 +42,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
|
||||||
#' @note LinearRegressionModel since 3.1.0
|
#' @note LinearRegressionModel since 3.1.0
|
||||||
setClass("LinearRegressionModel", representation(jobj = "jobj"))
|
setClass("LinearRegressionModel", representation(jobj = "jobj"))
|
||||||
|
|
||||||
|
#' S4 class that represents a FMRegressionModel
|
||||||
|
#'
|
||||||
|
#' @param jobj a Java object reference to the backing Scala FMRegressorWrapper
|
||||||
|
#' @note FMRegressionModel since 3.1.0
|
||||||
|
setClass("FMRegressionModel", representation(jobj = "jobj"))
|
||||||
|
|
||||||
#' Generalized Linear Models
|
#' Generalized Linear Models
|
||||||
#'
|
#'
|
||||||
#' Fits generalized linear model against a SparkDataFrame.
|
#' Fits generalized linear model against a SparkDataFrame.
|
||||||
|
@ -612,18 +618,22 @@ setMethod("spark.lm", signature(data = "SparkDataFrame", formula = "formula"),
|
||||||
stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
|
stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
|
||||||
"alphabetDesc", "alphabetAsc")) {
|
"alphabetDesc", "alphabetAsc")) {
|
||||||
|
|
||||||
|
|
||||||
formula <- paste(deparse(formula), collapse = "")
|
formula <- paste(deparse(formula), collapse = "")
|
||||||
|
|
||||||
|
|
||||||
solver <- match.arg(solver)
|
solver <- match.arg(solver)
|
||||||
loss <- match.arg(loss)
|
loss <- match.arg(loss)
|
||||||
stringIndexerOrderType <- match.arg(stringIndexerOrderType)
|
stringIndexerOrderType <- match.arg(stringIndexerOrderType)
|
||||||
|
|
||||||
|
|
||||||
if (!is.null(weightCol) && weightCol == "") {
|
if (!is.null(weightCol) && weightCol == "") {
|
||||||
weightCol <- NULL
|
weightCol <- NULL
|
||||||
} else if (!is.null(weightCol)) {
|
} else if (!is.null(weightCol)) {
|
||||||
weightCol <- as.character(weightCol)
|
weightCol <- as.character(weightCol)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
jobj <- callJStatic("org.apache.spark.ml.r.LinearRegressionWrapper",
|
jobj <- callJStatic("org.apache.spark.ml.r.LinearRegressionWrapper",
|
||||||
"fit",
|
"fit",
|
||||||
data@sdf,
|
data@sdf,
|
||||||
|
@ -642,8 +652,10 @@ setMethod("spark.lm", signature(data = "SparkDataFrame", formula = "formula"),
|
||||||
new("LinearRegressionModel", jobj = jobj)
|
new("LinearRegressionModel", jobj = jobj)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
# Returns the summary of a Linear Regression model produced by \code{spark.lm}
|
# Returns the summary of a Linear Regression model produced by \code{spark.lm}
|
||||||
|
|
||||||
|
|
||||||
#' @param object a Linear Regression Model model fitted by \code{spark.lm}.
|
#' @param object a Linear Regression Model model fitted by \code{spark.lm}.
|
||||||
#' @return \code{summary} returns summary information of the fitted model, which is a list.
|
#' @return \code{summary} returns summary information of the fitted model, which is a list.
|
||||||
#'
|
#'
|
||||||
|
@ -659,14 +671,17 @@ setMethod("summary", signature(object = "LinearRegressionModel"),
|
||||||
rownames(coefficients) <- unlist(features)
|
rownames(coefficients) <- unlist(features)
|
||||||
numFeatures <- callJMethod(jobj, "numFeatures")
|
numFeatures <- callJMethod(jobj, "numFeatures")
|
||||||
|
|
||||||
|
|
||||||
list(
|
list(
|
||||||
coefficients = coefficients,
|
coefficients = coefficients,
|
||||||
numFeatures = numFeatures
|
numFeatures = numFeatures
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
# Predicted values based on an LinearRegressionModel model
|
# Predicted values based on an LinearRegressionModel model
|
||||||
|
|
||||||
|
|
||||||
#' @param newData a SparkDataFrame for testing.
|
#' @param newData a SparkDataFrame for testing.
|
||||||
#' @return \code{predict} returns the predicted values based on a LinearRegressionModel.
|
#' @return \code{predict} returns the predicted values based on a LinearRegressionModel.
|
||||||
#'
|
#'
|
||||||
|
@ -678,8 +693,10 @@ setMethod("predict", signature(object = "LinearRegressionModel"),
|
||||||
predict_internal(object, newData)
|
predict_internal(object, newData)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
# Save fitted LinearRegressionModel to the input path
|
# Save fitted LinearRegressionModel to the input path
|
||||||
|
|
||||||
|
|
||||||
#' @param path The directory where the model is saved.
|
#' @param path The directory where the model is saved.
|
||||||
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
|
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
|
||||||
#' which means throw exception if the output path exists.
|
#' which means throw exception if the output path exists.
|
||||||
|
@ -691,3 +708,158 @@ setMethod("write.ml", signature(object = "LinearRegressionModel", path = "charac
|
||||||
function(object, path, overwrite = FALSE) {
|
function(object, path, overwrite = FALSE) {
|
||||||
write_internal(object, path, overwrite)
|
write_internal(object, path, overwrite)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' Factorization Machines Regression Model
|
||||||
|
#'
|
||||||
|
#' \code{spark.fmRegressor} fits a factorization machines regression model against a SparkDataFrame.
|
||||||
|
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
|
||||||
|
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
|
||||||
|
#'
|
||||||
|
#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
|
||||||
|
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
|
||||||
|
#' operators are supported, including '~', '.', ':', '+', and '-'.
|
||||||
|
#' @param factorSize dimensionality of the factors.
|
||||||
|
#' @param fitLinear whether to fit linear term. # TODO Can we express this with formula?
|
||||||
|
#' @param regParam the regularization parameter.
|
||||||
|
#' @param miniBatchFraction the mini-batch fraction parameter.
|
||||||
|
#' @param initStd the standard deviation of initial coefficients.
|
||||||
|
#' @param maxIter maximum iteration number.
|
||||||
|
#' @param stepSize step size to be used for each iteration of optimization.
|
||||||
|
#' @param tol convergence tolerance of iterations.
|
||||||
|
#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "adamW".
|
||||||
|
#' @param seed seed parameter for weights initialization.
|
||||||
|
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
|
||||||
|
#' decide the base level of a string feature as the last category
|
||||||
|
#' after ordering is dropped when encoding strings. Supported options
|
||||||
|
#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
|
||||||
|
#' "alphabetAsc". The default value is "frequencyDesc". When the
|
||||||
|
#' ordering is set to "alphabetDesc", this drops the same category
|
||||||
|
#' as R when encoding strings.
|
||||||
|
#' @param ... additional arguments passed to the method.
|
||||||
|
#' @return \code{spark.fmRegressor} returns a fitted Factorization Machines Regression Model.
|
||||||
|
#'
|
||||||
|
#' @rdname spark.fmRegressor
|
||||||
|
#' @aliases spark.fmRegressor,SparkDataFrame,formula-method
|
||||||
|
#' @name spark.fmRegressor
|
||||||
|
#' @seealso \link{read.ml}
|
||||||
|
#' @examples
|
||||||
|
#' \dontrun{
|
||||||
|
#' df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm")
|
||||||
|
#'
|
||||||
|
#' # fit Factorization Machines Regression Model
|
||||||
|
#' model <- spark.fmRegressor(
|
||||||
|
#' df, label ~ features,
|
||||||
|
#' regParam = 0.01, maxIter = 10, fitLinear = TRUE
|
||||||
|
#' )
|
||||||
|
#'
|
||||||
|
#' # get the summary of the model
|
||||||
|
#' summary(model)
|
||||||
|
#'
|
||||||
|
#' # make predictions
|
||||||
|
#' predictions <- predict(model, df)
|
||||||
|
#'
|
||||||
|
#' # save and load the model
|
||||||
|
#' path <- "path/to/model"
|
||||||
|
#' write.ml(model, path)
|
||||||
|
#' savedModel <- read.ml(path)
|
||||||
|
#' summary(savedModel)
|
||||||
|
#' }
|
||||||
|
#' @note spark.fmRegressor since 3.1.0
|
||||||
|
setMethod("spark.fmRegressor", signature(data = "SparkDataFrame", formula = "formula"),
          function(data, formula, factorSize = 8, fitLinear = TRUE, regParam = 0.0,
                   miniBatchFraction = 1.0, initStd = 0.01, maxIter = 100, stepSize = 1.0,
                   tol = 1e-6, solver = c("adamW", "gd"), seed = NULL,
                   stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
                                              "alphabetDesc", "alphabetAsc")) {

            # The JVM side takes the formula as a single string.
            formulaString <- paste(deparse(formula), collapse = "")

            # The seed crosses the bridge as a string; NULL means "no seed given".
            seedString <- if (is.null(seed)) NULL else as.character(as.integer(seed))

            # Collapse the choice vectors to the single selected option.
            solver <- match.arg(solver)
            stringIndexerOrderType <- match.arg(stringIndexerOrderType)

            # Coerce each argument to the type passed positionally to FMRegressorWrapper.fit.
            wrapper <- callJStatic(
              "org.apache.spark.ml.r.FMRegressorWrapper", "fit",
              data@sdf, formulaString,
              as.integer(factorSize), as.logical(fitLinear),
              as.numeric(regParam), as.numeric(miniBatchFraction),
              as.numeric(initStd), as.integer(maxIter),
              as.numeric(stepSize), as.numeric(tol),
              solver, seedString, stringIndexerOrderType
            )

            new("FMRegressionModel", jobj = wrapper)
          })
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the summary of a FM Regression model produced by \code{spark.fmRegressor}
|
||||||
|
|
||||||
|
|
||||||
|
#' @param object an FM Regression Model fitted by \code{spark.fmRegressor}.
|
||||||
|
#' @return \code{summary} returns summary information of the fitted model, which is a list.
|
||||||
|
#'
|
||||||
|
#' @rdname spark.fmRegressor
|
||||||
|
#' @note summary(FMRegressionModel) since 3.1.0
|
||||||
|
setMethod("summary", signature(object = "FMRegressionModel"),
          function(object) {
            wrapper <- object@jobj

            # Names and values of the linear terms, as exposed by the JVM wrapper.
            featureNames <- unlist(callJMethod(wrapper, "rFeatures"))
            estimates <- as.matrix(unlist(callJMethod(wrapper, "rCoefficients")))
            colnames(estimates) <- c("Estimate")
            rownames(estimates) <- featureNames

            featureCount <- callJMethod(wrapper, "numFeatures")

            # The factors arrive as a flat vector; reshape to one column per factor.
            flatFactors <- unlist(callJMethod(wrapper, "rFactors"))
            nFactors <- callJMethod(wrapper, "factorSize")
            factorMatrix <- matrix(flatFactors, ncol = nFactors)

            list(
              coefficients = estimates,
              factors = factorMatrix,
              numFeatures = featureCount,
              factorSize = nFactors
            )
          })
|
||||||
|
|
||||||
|
|
||||||
|
# Predicted values based on an FMRegressionModel
|
||||||
|
|
||||||
|
|
||||||
|
#' @param newData a SparkDataFrame for testing.
|
||||||
|
#' @return \code{predict} returns the predicted values based on an FMRegressionModel.
|
||||||
|
#'
|
||||||
|
#' @rdname spark.fmRegressor
|
||||||
|
#' @aliases predict,FMRegressionModel,SparkDataFrame-method
|
||||||
|
#' @note predict(FMRegressionModel) since 3.1.0
|
||||||
|
setMethod("predict", signature(object = "FMRegressionModel"),
          # Thin delegate: prediction is handled by the shared predict_internal helper.
          function(object, newData) predict_internal(object, newData))
|
||||||
|
|
||||||
|
|
||||||
|
# Save fitted FMRegressionModel to the input path
|
||||||
|
|
||||||
|
|
||||||
|
#' @param path The directory where the model is saved.
|
||||||
|
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
|
||||||
|
#' which means throw exception if the output path exists.
|
||||||
|
#'
|
||||||
|
#' @rdname spark.fmRegressor
|
||||||
|
#' @aliases write.ml,FMRegressionModel,character-method
|
||||||
|
#' @note write.ml(FMRegressionModel, character) since 3.1.0
|
||||||
|
setMethod("write.ml", signature(object = "FMRegressionModel", path = "character"),
          # Thin delegate: persistence is handled by the shared write_internal helper.
          function(object, path, overwrite = FALSE) write_internal(object, path, overwrite))
|
||||||
|
|
|
@ -127,6 +127,8 @@ read.ml <- function(path) {
|
||||||
new("FMClassificationModel", jobj = jobj)
|
new("FMClassificationModel", jobj = jobj)
|
||||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearRegressionWrapper")) {
|
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearRegressionWrapper")) {
|
||||||
new("LinearRegressionModel", jobj = jobj)
|
new("LinearRegressionModel", jobj = jobj)
|
||||||
|
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FMRegressorWrapper")) {
|
||||||
|
new("FMRegressionModel", jobj = jobj)
|
||||||
} else {
|
} else {
|
||||||
stop("Unsupported model: ", jobj)
|
stop("Unsupported model: ", jobj)
|
||||||
}
|
}
|
||||||
|
|
|
@ -580,4 +580,34 @@ test_that("spark.survreg", {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
test_that("spark.fmRegressor", {
  df <- suppressWarnings(createDataFrame(iris))

  # Fit on every remaining column, including the string-valued Species column.
  model <- spark.fmRegressor(
    df, Sepal_Width ~ .,
    regParam = 0.01, maxIter = 10, fitLinear = TRUE
  )

  prediction1 <- predict(model, df)
  expect_is(prediction1, "SparkDataFrame")

  # Test model save/load
  if (windows_with_hadoop()) {
    modelPath <- tempfile(pattern = "spark-fmregressor", fileext = ".tmp")
    write.ml(model, modelPath)
    model2 <- read.ml(modelPath)

    expect_is(model2, "FMRegressionModel")
    expect_equal(summary(model), summary(model2))

    # A reloaded model must reproduce the original predictions.
    prediction2 <- predict(model2, df)
    expect_equal(
      collect(prediction1),
      collect(prediction2)
    )

    # The saved model is a directory; unlink() only removes directories
    # when recursive = TRUE, so without it the temp model is left behind.
    unlink(modelPath, recursive = TRUE)
  }
})
|
||||||
|
|
||||||
sparkR.session.stop()
|
sparkR.session.stop()
|
||||||
|
|
|
@ -535,6 +535,8 @@ SparkR supports the following machine learning models and algorithms.
|
||||||
|
|
||||||
* Linear Regression
|
* Linear Regression
|
||||||
|
|
||||||
|
* Factorization Machines (FM) Regressor
|
||||||
|
|
||||||
#### Tree - Classification and Regression
|
#### Tree - Classification and Regression
|
||||||
|
|
||||||
* Decision Tree
|
* Decision Tree
|
||||||
|
@ -847,6 +849,20 @@ predictions <- predict(model, carsDF)
|
||||||
head(select(predictions, predictions$prediction))
|
head(select(predictions, predictions$prediction))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Factorization Machines Regressor
|
||||||
|
|
||||||
|
Factorization Machines for regression problems.
|
||||||
|
|
||||||
|
For background and details about the implementation of factorization machines,
|
||||||
|
refer to the [Factorization Machines section](https://spark.apache.org/docs/latest/ml-classification-regression.html#factorization-machines).
|
||||||
|
|
||||||
|
```{r}
|
||||||
|
model <- spark.fmRegressor(carsDF, mpg ~ wt + hp)
|
||||||
|
summary(model)
|
||||||
|
predictions <- predict(model, carsDF)
|
||||||
|
head(select(predictions, predictions$prediction))
|
||||||
|
```
|
||||||
|
|
||||||
#### Decision Tree
|
#### Decision Tree
|
||||||
|
|
||||||
`spark.decisionTree` fits a [decision tree](https://en.wikipedia.org/wiki/Decision_tree_learning) classification or regression model on a `SparkDataFrame`.
|
`spark.decisionTree` fits a [decision tree](https://en.wikipedia.org/wiki/Decision_tree_learning) classification or regression model on a `SparkDataFrame`.
|
||||||
|
|
|
@ -1101,6 +1101,15 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.
|
||||||
{% include_example python/ml/fm_regressor_example.py %}
|
{% include_example python/ml/fm_regressor_example.py %}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div data-lang="r" markdown="1">
|
||||||
|
|
||||||
|
Refer to the [R API documentation](api/R/spark.fmRegressor.html) for more details.
|
||||||
|
|
||||||
|
Note: At the moment SparkR doesn't support feature scaling.
|
||||||
|
|
||||||
|
{% include_example r/ml/fmRegressor.R %}
|
||||||
|
</div>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -548,6 +548,7 @@ SparkR supports the following machine learning algorithms currently:
|
||||||
* [`spark.glm`](api/R/spark.glm.html) or [`glm`](api/R/glm.html): [`Generalized Linear Model (GLM)`](ml-classification-regression.html#generalized-linear-regression)
|
* [`spark.glm`](api/R/spark.glm.html) or [`glm`](api/R/glm.html): [`Generalized Linear Model (GLM)`](ml-classification-regression.html#generalized-linear-regression)
|
||||||
* [`spark.isoreg`](api/R/spark.isoreg.html): [`Isotonic Regression`](ml-classification-regression.html#isotonic-regression)
|
* [`spark.isoreg`](api/R/spark.isoreg.html): [`Isotonic Regression`](ml-classification-regression.html#isotonic-regression)
|
||||||
* [`spark.lm`](api/R/spark.lm.html): [`Linear Regression`](ml-classification-regression.html#linear-regression)
|
* [`spark.lm`](api/R/spark.lm.html): [`Linear Regression`](ml-classification-regression.html#linear-regression)
|
||||||
|
* [`spark.fmRegressor`](api/R/spark.fmRegressor.html): [`Factorization Machines regressor`](ml-classification-regression.html#factorization-machines-regressor)
|
||||||
|
|
||||||
#### Tree
|
#### Tree
|
||||||
|
|
||||||
|
|
45
examples/src/main/r/ml/fmRegressor.R
Normal file
45
examples/src/main/r/ml/fmRegressor.R
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
# contributor license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright ownership.
|
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
# (the "License"); you may not use this file except in compliance with
|
||||||
|
# the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
# To run this example use
|
||||||
|
# ./bin/spark-submit examples/src/main/r/ml/fmRegressor.R
|
||||||
|
|
||||||
|
# Load SparkR library into your R session
|
||||||
|
library(SparkR)
|
||||||
|
|
||||||
|
# Initialize SparkSession
|
||||||
|
sparkR.session(appName = "SparkR-ML-fmRegressor-example")
|
||||||
|
|
||||||
|
# $example on$
|
||||||
|
# Load training data
|
||||||
|
df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm")
|
||||||
|
training_test <- randomSplit(df, c(0.7, 0.3))
|
||||||
|
training <- training_test[[1]]
|
||||||
|
test <- training_test[[2]]
|
||||||
|
|
||||||
|
# Fit a FM regression model
|
||||||
|
model <- spark.fmRegressor(training, label ~ features)
|
||||||
|
|
||||||
|
# Model summary
|
||||||
|
summary(model)
|
||||||
|
|
||||||
|
# Prediction
|
||||||
|
predictions <- predict(model, test)
|
||||||
|
head(predictions)
|
||||||
|
# $example off$
|
||||||
|
|
||||||
|
sparkR.session.stop()
|
|
@ -0,0 +1,155 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.ml.r
|
||||||
|
|
||||||
|
import org.apache.hadoop.fs.Path
|
||||||
|
import org.json4s._
|
||||||
|
import org.json4s.JsonDSL._
|
||||||
|
import org.json4s.jackson.JsonMethods._
|
||||||
|
|
||||||
|
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||||
|
import org.apache.spark.ml.attribute.AttributeGroup
|
||||||
|
import org.apache.spark.ml.feature.RFormula
|
||||||
|
import org.apache.spark.ml.r.RWrapperUtils._
|
||||||
|
import org.apache.spark.ml.regression.{FMRegressionModel, FMRegressor}
|
||||||
|
import org.apache.spark.ml.util._
|
||||||
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
|
|
||||||
|
/**
 * Wrapper that exposes a fitted factorization machines regression pipeline to SparkR.
 *
 * @param pipeline fitted pipeline; the stage at index 1 is read as the [[FMRegressionModel]]
 * @param features names of the feature columns, used to label the linear coefficients
 */
private[r] class FMRegressorWrapper private (
    val pipeline: PipelineModel,
    val features: Array[String]) extends MLWritable {

  // The companion's `fit` builds the pipeline as (RFormulaModel, FMRegressor),
  // so the fitted regression model is the second stage.
  private val fmRegressionModel: FMRegressionModel =
    pipeline.stages(1).asInstanceOf[FMRegressionModel]

  /** Feature names, with "(Intercept)" prepended when the model fits an intercept. */
  lazy val rFeatures: Array[String] = if (fmRegressionModel.getFitIntercept) {
    Array("(Intercept)") ++ features
  } else {
    features
  }

  /** Linear coefficients, with the intercept prepended when the model fits one. */
  lazy val rCoefficients: Array[Double] = if (fmRegressionModel.getFitIntercept) {
    Array(fmRegressionModel.intercept) ++ fmRegressionModel.linear.toArray
  } else {
    fmRegressionModel.linear.toArray
  }

  /** Factor matrix of the model, flattened to an array via `toArray`. */
  lazy val rFactors: Array[Double] = fmRegressionModel.factors.toArray

  lazy val numFeatures: Int = fmRegressionModel.numFeatures

  lazy val factorSize: Int = fmRegressionModel.getFactorSize

  /** Runs the whole pipeline on `dataset` and drops the assembled features column. */
  def transform(dataset: Dataset[_]): DataFrame = {
    pipeline.transform(dataset)
      .drop(fmRegressionModel.getFeaturesCol)
  }

  override def write: MLWriter = new FMRegressorWrapper.FMRegressorWrapperWriter(this)
}
|
||||||
|
|
||||||
|
/**
 * Companion of the SparkR FM regressor wrapper: fits new wrappers from an R formula
 * and provides the reader/writer used to persist them.
 */
private[r] object FMRegressorWrapper
  extends MLReadable[FMRegressorWrapper] {

  /**
   * Fits an RFormula + FMRegressor pipeline on `data` and wraps the result.
   *
   * `seed` is passed as a string; a null or empty value leaves the estimator's
   * seed untouched.
   */
  def fit(  // scalastyle:ignore
      data: DataFrame,
      formula: String,
      factorSize: Int,
      fitLinear: Boolean,
      regParam: Double,
      miniBatchFraction: Double,
      initStd: Double,
      maxIter: Int,
      stepSize: Double,
      tol: Double,
      solver: String,
      seed: String,
      stringIndexerOrderType: String): FMRegressorWrapper = {

    val rFormula = new RFormula()
      .setFormula(formula)
      .setStringIndexerOrderType(stringIndexerOrderType)
    checkDataColumns(rFormula, data)
    val formulaModel = rFormula.fit(data)

    // The intercept setting is taken from the formula, not from a separate argument.
    val fitIntercept = rFormula.hasIntercept

    // get feature names from output schema
    val outputSchema = formulaModel.transform(data).schema
    val featuresField = outputSchema(formulaModel.getFeaturesCol)
    val featureNames = AttributeGroup.fromStructField(featuresField)
      .attributes.get
      .map(_.name.get)

    // assemble and fit the pipeline
    val regressor = new FMRegressor()
      .setFactorSize(factorSize)
      .setFitIntercept(fitIntercept)
      .setFitLinear(fitLinear)
      .setRegParam(regParam)
      .setMiniBatchFraction(miniBatchFraction)
      .setInitStd(initStd)
      .setMaxIter(maxIter)
      .setStepSize(stepSize)
      .setTol(tol)
      .setSolver(solver)
      .setFeaturesCol(rFormula.getFeaturesCol)

    // Only set the seed when a non-empty value was supplied.
    Option(seed).filter(_.nonEmpty).foreach(s => regressor.setSeed(s.toLong))

    val fittedPipeline = new Pipeline()
      .setStages(Array(formulaModel, regressor))
      .fit(data)

    new FMRegressorWrapper(fittedPipeline, featureNames)
  }

  override def read: MLReader[FMRegressorWrapper] = new FMRegressorWrapperReader

  /** Saves the R-side metadata as JSON and the fitted pipeline under `path`. */
  class FMRegressorWrapperWriter(instance: FMRegressorWrapper) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val metadataDir = new Path(path, "rMetadata").toString
      val pipelineDir = new Path(path, "pipeline").toString

      val metadata = ("class" -> instance.getClass.getName) ~
        ("features" -> instance.features.toSeq)
      sc.parallelize(Seq(compact(render(metadata))), 1).saveAsTextFile(metadataDir)

      instance.pipeline.save(pipelineDir)
    }
  }

  /** Restores a wrapper from the layout written by [[FMRegressorWrapperWriter]]. */
  class FMRegressorWrapperReader extends MLReader[FMRegressorWrapper] {

    override def load(path: String): FMRegressorWrapper = {
      implicit val format = DefaultFormats
      val metadataDir = new Path(path, "rMetadata").toString
      val pipelineDir = new Path(path, "pipeline").toString

      // The metadata is stored as a single JSON line.
      val metadata = parse(sc.textFile(metadataDir, 1).first())
      val featureNames = (metadata \ "features").extract[Array[String]]

      new FMRegressorWrapper(PipelineModel.load(pipelineDir), featureNames)
    }
  }
}
|
|
@ -78,6 +78,8 @@ private[r] object RWrappers extends MLReader[Object] {
|
||||||
FMClassifierWrapper.load(path)
|
FMClassifierWrapper.load(path)
|
||||||
case "org.apache.spark.ml.r.LinearRegressionWrapper" =>
|
case "org.apache.spark.ml.r.LinearRegressionWrapper" =>
|
||||||
LinearRegressionWrapper.load(path)
|
LinearRegressionWrapper.load(path)
|
||||||
|
case "org.apache.spark.ml.r.FMRegressorWrapper" =>
|
||||||
|
FMRegressorWrapper.load(path)
|
||||||
case _ =>
|
case _ =>
|
||||||
throw new SparkException(s"SparkR read.ml does not support load $className")
|
throw new SparkException(s"SparkR read.ml does not support load $className")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue