[SPARK-30819][SPARKR][ML] Add FMRegressor wrapper to SparkR
### What changes were proposed in this pull request? This pull request adds SparkR wrapper for `FMRegressor`: - Supporting ` org.apache.spark.ml.r.FMRegressorWrapper`. - `FMRegressionModel` S4 class. - Corresponding `spark.fmRegressor`, `predict`, `summary` and `write.ml` generics. - Corresponding docs and tests. ### Why are the changes needed? Feature parity. ### Does this PR introduce any user-facing change? No (new API). ### How was this patch tested? New unit tests. Closes #27571 from zero323/SPARK-30819. Authored-by: zero323 <mszymkiewicz@gmail.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
61f903fa7a
commit
697fe911ac
|
@ -74,7 +74,8 @@ exportMethods("glm",
|
|||
"spark.findFrequentSequentialPatterns",
|
||||
"spark.assignClusters",
|
||||
"spark.fmClassifier",
|
||||
"spark.lm")
|
||||
"spark.lm",
|
||||
"spark.fmRegressor")
|
||||
|
||||
# Job group lifecycle management methods
|
||||
export("setJobGroup",
|
||||
|
|
|
@ -1483,6 +1483,10 @@ setGeneric("spark.bisectingKmeans",
|
|||
setGeneric("spark.fmClassifier",
|
||||
function(data, formula, ...) { standardGeneric("spark.fmClassifier") })
|
||||
|
||||
#' @rdname spark.fmRegressor
|
||||
setGeneric("spark.fmRegressor",
|
||||
function(data, formula, ...) { standardGeneric("spark.fmRegressor") })
|
||||
|
||||
#' @rdname spark.gaussianMixture
|
||||
setGeneric("spark.gaussianMixture",
|
||||
function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
|
||||
|
|
|
@ -42,6 +42,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
|
|||
#' @note LinearRegressionModel since 3.1.0
|
||||
setClass("LinearRegressionModel", representation(jobj = "jobj"))
|
||||
|
||||
#' S4 class that represents a FMRegressionModel
|
||||
#'
|
||||
#' @param jobj a Java object reference to the backing Scala FMRegressorWrapper
|
||||
#' @note FMRegressionModel since 3.1.0
|
||||
setClass("FMRegressionModel", representation(jobj = "jobj"))
|
||||
|
||||
#' Generalized Linear Models
|
||||
#'
|
||||
#' Fits generalized linear model against a SparkDataFrame.
|
||||
|
@ -612,18 +618,22 @@ setMethod("spark.lm", signature(data = "SparkDataFrame", formula = "formula"),
|
|||
stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
|
||||
"alphabetDesc", "alphabetAsc")) {
|
||||
|
||||
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
|
||||
|
||||
solver <- match.arg(solver)
|
||||
loss <- match.arg(loss)
|
||||
stringIndexerOrderType <- match.arg(stringIndexerOrderType)
|
||||
|
||||
|
||||
if (!is.null(weightCol) && weightCol == "") {
|
||||
weightCol <- NULL
|
||||
} else if (!is.null(weightCol)) {
|
||||
weightCol <- as.character(weightCol)
|
||||
}
|
||||
|
||||
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.LinearRegressionWrapper",
|
||||
"fit",
|
||||
data@sdf,
|
||||
|
@ -642,8 +652,10 @@ setMethod("spark.lm", signature(data = "SparkDataFrame", formula = "formula"),
|
|||
new("LinearRegressionModel", jobj = jobj)
|
||||
})
|
||||
|
||||
|
||||
# Returns the summary of a Linear Regression model produced by \code{spark.lm}
|
||||
|
||||
|
||||
#' @param object a Linear Regression Model model fitted by \code{spark.lm}.
|
||||
#' @return \code{summary} returns summary information of the fitted model, which is a list.
|
||||
#'
|
||||
|
@ -659,14 +671,17 @@ setMethod("summary", signature(object = "LinearRegressionModel"),
|
|||
rownames(coefficients) <- unlist(features)
|
||||
numFeatures <- callJMethod(jobj, "numFeatures")
|
||||
|
||||
|
||||
list(
|
||||
coefficients = coefficients,
|
||||
numFeatures = numFeatures
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Predicted values based on an LinearRegressionModel model
|
||||
|
||||
|
||||
#' @param newData a SparkDataFrame for testing.
|
||||
#' @return \code{predict} returns the predicted values based on a LinearRegressionModel.
|
||||
#'
|
||||
|
@ -678,8 +693,10 @@ setMethod("predict", signature(object = "LinearRegressionModel"),
|
|||
predict_internal(object, newData)
|
||||
})
|
||||
|
||||
|
||||
# Save fitted LinearRegressionModel to the input path
|
||||
|
||||
|
||||
#' @param path The directory where the model is saved.
|
||||
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
|
||||
#' which means throw exception if the output path exists.
|
||||
|
@ -691,3 +708,158 @@ setMethod("write.ml", signature(object = "LinearRegressionModel", path = "charac
|
|||
function(object, path, overwrite = FALSE) {
|
||||
write_internal(object, path, overwrite)
|
||||
})
|
||||
|
||||
#' Factorization Machines Regression Model
|
||||
#'
|
||||
#' \code{spark.fmRegressor} fits a factorization regression model against a SparkDataFrame.
|
||||
#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
|
||||
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
|
||||
#'
|
||||
#' @param data a \code{SparkDataFrame} of observations and labels for model fitting.
|
||||
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
|
||||
#' operators are supported, including '~', '.', ':', '+', and '-'.
|
||||
#' @param factorSize dimensionality of the factors.
|
||||
#' @param fitLinear whether to fit linear term. # TODO Can we express this with formula?
|
||||
#' @param regParam the regularization parameter.
|
||||
#' @param miniBatchFraction the mini-batch fraction parameter.
|
||||
#' @param initStd the standard deviation of initial coefficients.
|
||||
#' @param maxIter maximum iteration number.
|
||||
#' @param stepSize stepSize parameter.
|
||||
#' @param tol convergence tolerance of iterations.
|
||||
#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "adamW".
|
||||
#' @param seed seed parameter for weights initialization.
|
||||
#' @param stringIndexerOrderType how to order categories of a string feature column. This is used to
|
||||
#' decide the base level of a string feature as the last category
|
||||
#' after ordering is dropped when encoding strings. Supported options
|
||||
#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and
|
||||
#' "alphabetAsc". The default value is "frequencyDesc". When the
|
||||
#' ordering is set to "alphabetDesc", this drops the same category
|
||||
#' as R when encoding strings.
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @return \code{spark.fmRegressor} returns a fitted Factorization Machines Regression Model.
|
||||
#'
|
||||
#' @rdname spark.fmRegressor
|
||||
#' @aliases spark.fmRegressor,SparkDataFrame,formula-method
|
||||
#' @name spark.fmRegressor
|
||||
#' @seealso \link{read.ml}
|
||||
#' @examples
|
||||
#' \dontrun{
|
||||
#' df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm")
|
||||
#'
|
||||
#' # fit Factorization Machines Regression Model
|
||||
#' model <- spark.fmRegressor(
|
||||
#' df, label ~ features,
|
||||
#' regParam = 0.01, maxIter = 10, fitLinear = TRUE
|
||||
#' )
|
||||
#'
|
||||
#' # get the summary of the model
|
||||
#' summary(model)
|
||||
#'
|
||||
#' # make predictions
|
||||
#' predictions <- predict(model, df)
|
||||
#'
|
||||
#' # save and load the model
|
||||
#' path <- "path/to/model"
|
||||
#' write.ml(model, path)
|
||||
#' savedModel <- read.ml(path)
|
||||
#' summary(savedModel)
|
||||
#' }
|
||||
#' @note spark.fmRegressor since 3.1.0
|
||||
setMethod("spark.fmRegressor", signature(data = "SparkDataFrame", formula = "formula"),
|
||||
function(data, formula, factorSize = 8, fitLinear = TRUE, regParam = 0.0,
|
||||
miniBatchFraction = 1.0, initStd = 0.01, maxIter = 100, stepSize=1.0,
|
||||
tol = 1e-6, solver = c("adamW", "gd"), seed = NULL,
|
||||
stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
|
||||
"alphabetDesc", "alphabetAsc")) {
|
||||
|
||||
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
|
||||
|
||||
if (!is.null(seed)) {
|
||||
seed <- as.character(as.integer(seed))
|
||||
}
|
||||
|
||||
|
||||
solver <- match.arg(solver)
|
||||
stringIndexerOrderType <- match.arg(stringIndexerOrderType)
|
||||
|
||||
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.FMRegressorWrapper",
|
||||
"fit",
|
||||
data@sdf,
|
||||
formula,
|
||||
as.integer(factorSize),
|
||||
as.logical(fitLinear),
|
||||
as.numeric(regParam),
|
||||
as.numeric(miniBatchFraction),
|
||||
as.numeric(initStd),
|
||||
as.integer(maxIter),
|
||||
as.numeric(stepSize),
|
||||
as.numeric(tol),
|
||||
solver,
|
||||
seed,
|
||||
stringIndexerOrderType)
|
||||
new("FMRegressionModel", jobj = jobj)
|
||||
})
|
||||
|
||||
|
||||
# Returns the summary of a FM Regression model produced by \code{spark.fmRegressor}
|
||||
|
||||
|
||||
#' @param object a FM Regression Model model fitted by \code{spark.fmRegressor}.
|
||||
#' @return \code{summary} returns summary information of the fitted model, which is a list.
|
||||
#'
|
||||
#' @rdname spark.fmRegressor
|
||||
#' @note summary(FMRegressionModel) since 3.1.0
|
||||
setMethod("summary", signature(object = "FMRegressionModel"),
|
||||
function(object) {
|
||||
jobj <- object@jobj
|
||||
features <- callJMethod(jobj, "rFeatures")
|
||||
coefficients <- callJMethod(jobj, "rCoefficients")
|
||||
coefficients <- as.matrix(unlist(coefficients))
|
||||
colnames(coefficients) <- c("Estimate")
|
||||
rownames(coefficients) <- unlist(features)
|
||||
numFeatures <- callJMethod(jobj, "numFeatures")
|
||||
raw_factors <- unlist(callJMethod(jobj, "rFactors"))
|
||||
factor_size <- callJMethod(jobj, "factorSize")
|
||||
|
||||
|
||||
list(
|
||||
coefficients = coefficients,
|
||||
factors = matrix(raw_factors, ncol = factor_size),
|
||||
numFeatures = numFeatures,
|
||||
factorSize = factor_size
|
||||
)
|
||||
})
|
||||
|
||||
|
||||
# Predicted values based on an FMRegressionModel model
|
||||
|
||||
|
||||
#' @param newData a SparkDataFrame for testing.
|
||||
#' @return \code{predict} returns the predicted values based on an FMRegressionModel.
|
||||
#'
|
||||
#' @rdname spark.fmRegressor
|
||||
#' @aliases predict,FMRegressionModel,SparkDataFrame-method
|
||||
#' @note predict(FMRegressionModel) since 3.1.0
|
||||
setMethod("predict", signature(object = "FMRegressionModel"),
|
||||
function(object, newData) {
|
||||
predict_internal(object, newData)
|
||||
})
|
||||
|
||||
|
||||
# Save fitted FMRegressionModel to the input path
|
||||
|
||||
|
||||
#' @param path The directory where the model is saved.
|
||||
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
|
||||
#' which means throw exception if the output path exists.
|
||||
#'
|
||||
#' @rdname spark.fmRegressor
|
||||
#' @aliases write.ml,FMRegressionModel,character-method
|
||||
#' @note write.ml(FMRegressionModel, character) since 3.1.0
|
||||
setMethod("write.ml", signature(object = "FMRegressionModel", path = "character"),
|
||||
function(object, path, overwrite = FALSE) {
|
||||
write_internal(object, path, overwrite)
|
||||
})
|
||||
|
|
|
@ -127,6 +127,8 @@ read.ml <- function(path) {
|
|||
new("FMClassificationModel", jobj = jobj)
|
||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearRegressionWrapper")) {
|
||||
new("LinearRegressionModel", jobj = jobj)
|
||||
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FMRegressorWrapper")) {
|
||||
new("FMRegressionModel", jobj = jobj)
|
||||
} else {
|
||||
stop("Unsupported model: ", jobj)
|
||||
}
|
||||
|
|
|
@ -580,4 +580,34 @@ test_that("spark.survreg", {
|
|||
})
|
||||
})
|
||||
|
||||
|
||||
test_that("spark.fmRegressor", {
|
||||
df <- suppressWarnings(createDataFrame(iris))
|
||||
|
||||
model <- spark.fmRegressor(
|
||||
df, Sepal_Width ~ .,
|
||||
regParam = 0.01, maxIter = 10, fitLinear = TRUE
|
||||
)
|
||||
|
||||
prediction1 <- predict(model, df)
|
||||
expect_is(prediction1, "SparkDataFrame")
|
||||
|
||||
# Test model save/load
|
||||
if (windows_with_hadoop()) {
|
||||
modelPath <- tempfile(pattern = "spark-fmregressor", fileext = ".tmp")
|
||||
write.ml(model, modelPath)
|
||||
model2 <- read.ml(modelPath)
|
||||
|
||||
expect_is(model2, "FMRegressionModel")
|
||||
expect_equal(summary(model), summary(model2))
|
||||
|
||||
prediction2 <- predict(model2, df)
|
||||
expect_equal(
|
||||
collect(prediction1),
|
||||
collect(prediction2)
|
||||
)
|
||||
unlink(modelPath)
|
||||
}
|
||||
})
|
||||
|
||||
sparkR.session.stop()
|
||||
|
|
|
@ -535,6 +535,8 @@ SparkR supports the following machine learning models and algorithms.
|
|||
|
||||
* Linear Regression
|
||||
|
||||
* Factorization Machines (FM) Regressor
|
||||
|
||||
#### Tree - Classification and Regression
|
||||
|
||||
* Decision Tree
|
||||
|
@ -847,6 +849,20 @@ predictions <- predict(model, carsDF)
|
|||
head(select(predictions, predictions$prediction))
|
||||
```
|
||||
|
||||
#### Factorization Machines Regressor
|
||||
|
||||
Factorization Machines for regression problems.
|
||||
|
||||
For background and details about the implementation of factorization machines,
|
||||
refer to the [Factorization Machines section](https://spark.apache.org/docs/latest/ml-classification-regression.html#factorization-machines).
|
||||
|
||||
```{r}
|
||||
model <- spark.fmRegressor(carsDF, mpg ~ wt + hp)
|
||||
summary(model)
|
||||
predictions <- predict(model, carsDF)
|
||||
head(select(predictions, predictions$prediction))
|
||||
```
|
||||
|
||||
#### Decision Tree
|
||||
|
||||
`spark.decisionTree` fits a [decision tree](https://en.wikipedia.org/wiki/Decision_tree_learning) classification or regression model on a `SparkDataFrame`.
|
||||
|
|
|
@ -1101,6 +1101,15 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.
|
|||
{% include_example python/ml/fm_regressor_example.py %}
|
||||
</div>
|
||||
|
||||
<div data-lang="r" markdown="1">
|
||||
|
||||
Refer to the [R API documentation](api/R/spark.fmRegressor.html) for more details.
|
||||
|
||||
Note: At the moment SparkR doesn't suport feature scaling.
|
||||
|
||||
{% include_example r/ml/fmRegressor.R %}
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
|
|
|
@ -548,6 +548,7 @@ SparkR supports the following machine learning algorithms currently:
|
|||
* [`spark.glm`](api/R/spark.glm.html) or [`glm`](api/R/glm.html): [`Generalized Linear Model (GLM)`](ml-classification-regression.html#generalized-linear-regression)
|
||||
* [`spark.isoreg`](api/R/spark.isoreg.html): [`Isotonic Regression`](ml-classification-regression.html#isotonic-regression)
|
||||
* [`spark.lm`](api/R/spark.lm.html): [`Linear Regression`](ml-classification-regression.html#linear-regression)
|
||||
* [`spark.fmRegressor`](api/R/spark.fmRegressor.html): [`Factorization Machines regressor`](ml-classification-regression.html#factorization-machines-regressor)
|
||||
|
||||
#### Tree
|
||||
|
||||
|
|
45
examples/src/main/r/ml/fmRegressor.R
Normal file
45
examples/src/main/r/ml/fmRegressor.R
Normal file
|
@ -0,0 +1,45 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# To run this example use
|
||||
# ./bin/spark-submit examples/src/main/r/ml/fmRegressor.R
|
||||
|
||||
# Load SparkR library into your R session
|
||||
library(SparkR)
|
||||
|
||||
# Initialize SparkSession
|
||||
sparkR.session(appName = "SparkR-ML-fmRegressor-example")
|
||||
|
||||
# $example on$
|
||||
# Load training data
|
||||
df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm")
|
||||
training_test <- randomSplit(df, c(0.7, 0.3))
|
||||
training <- training_test[[1]]
|
||||
test <- training_test[[2]]
|
||||
|
||||
# Fit a FM regression model
|
||||
model <- spark.fmRegressor(training, label ~ features)
|
||||
|
||||
# Model summary
|
||||
summary(model)
|
||||
|
||||
# Prediction
|
||||
predictions <- predict(model, test)
|
||||
head(predictions)
|
||||
# $example off$
|
||||
|
||||
sparkR.session.stop()
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.r
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.json4s._
|
||||
import org.json4s.JsonDSL._
|
||||
import org.json4s.jackson.JsonMethods._
|
||||
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.ml.attribute.AttributeGroup
|
||||
import org.apache.spark.ml.feature.RFormula
|
||||
import org.apache.spark.ml.r.RWrapperUtils._
|
||||
import org.apache.spark.ml.regression.{FMRegressionModel, FMRegressor}
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
|
||||
private[r] class FMRegressorWrapper private (
|
||||
val pipeline: PipelineModel,
|
||||
val features: Array[String]) extends MLWritable {
|
||||
|
||||
private val fmRegressionModel: FMRegressionModel =
|
||||
pipeline.stages(1).asInstanceOf[FMRegressionModel]
|
||||
|
||||
lazy val rFeatures: Array[String] = if (fmRegressionModel.getFitIntercept) {
|
||||
Array("(Intercept)") ++ features
|
||||
} else {
|
||||
features
|
||||
}
|
||||
|
||||
lazy val rCoefficients: Array[Double] = if (fmRegressionModel.getFitIntercept) {
|
||||
Array(fmRegressionModel.intercept) ++ fmRegressionModel.linear.toArray
|
||||
} else {
|
||||
fmRegressionModel.linear.toArray
|
||||
}
|
||||
|
||||
lazy val rFactors = fmRegressionModel.factors.toArray
|
||||
|
||||
lazy val numFeatures: Int = fmRegressionModel.numFeatures
|
||||
|
||||
lazy val factorSize: Int = fmRegressionModel.getFactorSize
|
||||
|
||||
def transform(dataset: Dataset[_]): DataFrame = {
|
||||
pipeline.transform(dataset)
|
||||
.drop(fmRegressionModel.getFeaturesCol)
|
||||
}
|
||||
|
||||
override def write: MLWriter = new FMRegressorWrapper.FMRegressorWrapperWriter(this)
|
||||
}
|
||||
|
||||
private[r] object FMRegressorWrapper
|
||||
extends MLReadable[FMRegressorWrapper] {
|
||||
|
||||
def fit( // scalastyle:ignore
|
||||
data: DataFrame,
|
||||
formula: String,
|
||||
factorSize: Int,
|
||||
fitLinear: Boolean,
|
||||
regParam: Double,
|
||||
miniBatchFraction: Double,
|
||||
initStd: Double,
|
||||
maxIter: Int,
|
||||
stepSize: Double,
|
||||
tol: Double,
|
||||
solver: String,
|
||||
seed: String,
|
||||
stringIndexerOrderType: String): FMRegressorWrapper = {
|
||||
|
||||
val rFormula = new RFormula()
|
||||
.setFormula(formula)
|
||||
.setStringIndexerOrderType(stringIndexerOrderType)
|
||||
checkDataColumns(rFormula, data)
|
||||
val rFormulaModel = rFormula.fit(data)
|
||||
|
||||
val fitIntercept = rFormula.hasIntercept
|
||||
|
||||
// get feature names from output schema
|
||||
val schema = rFormulaModel.transform(data).schema
|
||||
val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
|
||||
.attributes.get
|
||||
val features = featureAttrs.map(_.name.get)
|
||||
|
||||
// assemble and fit the pipeline
|
||||
val fmr = new FMRegressor()
|
||||
.setFactorSize(factorSize)
|
||||
.setFitIntercept(fitIntercept)
|
||||
.setFitLinear(fitLinear)
|
||||
.setRegParam(regParam)
|
||||
.setMiniBatchFraction(miniBatchFraction)
|
||||
.setInitStd(initStd)
|
||||
.setMaxIter(maxIter)
|
||||
.setStepSize(stepSize)
|
||||
.setTol(tol)
|
||||
.setSolver(solver)
|
||||
.setFeaturesCol(rFormula.getFeaturesCol)
|
||||
|
||||
if (seed != null && seed.length > 0) {
|
||||
fmr.setSeed(seed.toLong)
|
||||
}
|
||||
|
||||
val pipeline = new Pipeline()
|
||||
.setStages(Array(rFormulaModel, fmr))
|
||||
.fit(data)
|
||||
|
||||
new FMRegressorWrapper(pipeline, features)
|
||||
}
|
||||
|
||||
override def read: MLReader[FMRegressorWrapper] = new FMRegressorWrapperReader
|
||||
|
||||
class FMRegressorWrapperWriter(instance: FMRegressorWrapper) extends MLWriter {
|
||||
|
||||
override protected def saveImpl(path: String): Unit = {
|
||||
val rMetadataPath = new Path(path, "rMetadata").toString
|
||||
val pipelinePath = new Path(path, "pipeline").toString
|
||||
|
||||
val rMetadata = ("class" -> instance.getClass.getName) ~
|
||||
("features" -> instance.features.toSeq)
|
||||
val rMetadataJson: String = compact(render(rMetadata))
|
||||
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
|
||||
|
||||
instance.pipeline.save(pipelinePath)
|
||||
}
|
||||
}
|
||||
|
||||
class FMRegressorWrapperReader extends MLReader[FMRegressorWrapper] {
|
||||
|
||||
override def load(path: String): FMRegressorWrapper = {
|
||||
implicit val format = DefaultFormats
|
||||
val rMetadataPath = new Path(path, "rMetadata").toString
|
||||
val pipelinePath = new Path(path, "pipeline").toString
|
||||
|
||||
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
|
||||
val rMetadata = parse(rMetadataStr)
|
||||
val features = (rMetadata \ "features").extract[Array[String]]
|
||||
|
||||
val pipeline = PipelineModel.load(pipelinePath)
|
||||
new FMRegressorWrapper(pipeline, features)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -78,6 +78,8 @@ private[r] object RWrappers extends MLReader[Object] {
|
|||
FMClassifierWrapper.load(path)
|
||||
case "org.apache.spark.ml.r.LinearRegressionWrapper" =>
|
||||
LinearRegressionWrapper.load(path)
|
||||
case "org.apache.spark.ml.r.FMRegressorWrapper" =>
|
||||
FMRegressorWrapper.load(path)
|
||||
case _ =>
|
||||
throw new SparkException(s"SparkR read.ml does not support load $className")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue