[SPARK-21381][SPARKR] SparkR: pass on setHandleInvalid for classification algorithms

## What changes were proposed in this pull request?

SPARK-20307 added the `handleInvalid` option to RFormula for the tree-based classification algorithms. This change adds the same parameter to the remaining classification algorithms in SparkR: `spark.svmLinear`, `spark.logit`, `spark.mlp`, `spark.naiveBayes`, `spark.gbt`, and `spark.decisionTree`.

This is a follow-up PR to SPARK-20307.
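
As an illustration of the behavior this enables (a minimal sketch with hypothetical toy data, not code from this patch; `spark.logit` stands in for any of the updated functions):

```r
library(SparkR)
sparkR.session()

# Train on a string feature column with two levels
train <- createDataFrame(data.frame(clicked = c(0, 1, 0, 1),
                                    someString = c("this", "that", "this", "that"),
                                    stringsAsFactors = FALSE))
# Test data contains a level ("other") never seen during training
test <- createDataFrame(data.frame(clicked = 0, someString = "other",
                                   stringsAsFactors = FALSE))

# Default "error": collecting predictions for the unseen level fails
model <- spark.logit(train, clicked ~ .)

# "skip" silently filters such rows; "keep" indexes them into an extra
# bucket at index numLabels instead of failing
model <- spark.logit(train, clicked ~ ., handleInvalid = "skip")
head(predict(model, test))
```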

## How was this patch tested?

New unit tests are added.

Author: wangmiao1981 <wm624@hotmail.com>

Closes #18605 from wangmiao1981/class.
Commit 9570e81aa9 (parent 6b186c9d60), authored by wangmiao1981 on 2017-07-31 20:37:06 -07:00 and committed by Felix Cheung.
10 changed files with 175 additions and 24 deletions.

R/pkg/R/mllib_classification.R

@@ -69,6 +69,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
 #'                         or the number of partitions are large, this param could be adjusted to a larger size.
 #'                         This is an expert parameter. Default value should be good for most cases.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.svmLinear} returns a fitted linear SVM model.
 #' @rdname spark.svmLinear
@@ -98,7 +103,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @note spark.svmLinear since 2.2.0
 setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
-                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
+                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(weightCol) && weightCol == "") {
@@ -107,10 +113,12 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
               weightCol <- as.character(weightCol)
             }
 
+            handleInvalid <- match.arg(handleInvalid)
             jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
                                 data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
                                 as.numeric(tol), as.logical(standardization), as.numeric(threshold),
-                                weightCol, as.integer(aggregationDepth))
+                                weightCol, as.integer(aggregationDepth), handleInvalid)
             new("LinearSVCModel", jobj = jobj)
           })
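
For illustration, a minimal sketch of the new `spark.svmLinear` signature in use (assumes `traindf` and `testdf` SparkDataFrames like those built in the tests below):

```r
# Drop rows whose string columns contain levels unseen at training time
model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1,
                         handleInvalid = "skip")
predictions <- predict(model, testdf)
```
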
@@ -218,6 +226,11 @@ function(object, path, overwrite = FALSE) {
 #' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization.
 #'                                The bound vector size must be equal to 1 for binomial regression, or the number
 #'                                of classes for multinomial regression.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model.
 #' @rdname spark.logit
@@ -257,7 +270,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                    tol = 1E-6, family = "auto", standardization = TRUE,
                    thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
                    lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL,
-                   lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) {
+                   lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
             row <- 0
             col <- 0
@@ -304,6 +318,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
               upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients))
             }
 
+            handleInvalid <- match.arg(handleInvalid)
+
             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
                                 data@sdf, formula, as.numeric(regParam),
                                 as.numeric(elasticNetParam), as.integer(maxIter),
@@ -312,7 +328,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 weightCol, as.integer(aggregationDepth),
                                 as.integer(row), as.integer(col),
                                 lowerBoundsOnCoefficients, upperBoundsOnCoefficients,
-                                lowerBoundsOnIntercepts, upperBoundsOnIntercepts)
+                                lowerBoundsOnIntercepts, upperBoundsOnIntercepts,
+                                handleInvalid)
             new("LogisticRegressionModel", jobj = jobj)
           })
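
Since the argument defaults to `c("error", "keep", "skip")` and is resolved with `match.arg`, an unsupported value fails fast on the R side before any JVM call (sketch, same assumed `traindf` as above):

```r
# Raises an R error: 'arg' should be one of "error", "keep", "skip"
model <- spark.logit(traindf, clicked ~ ., handleInvalid = "bogus")
```
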
@@ -394,7 +411,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
 #' @param stepSize stepSize parameter.
 #' @param seed seed parameter for weights initialization.
 #' @param initialWeights initialWeights parameter for weights initialization, it should be a
-#'        numeric vector.
+#'                       numeric vector.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model.
 #' @rdname spark.mlp
@@ -426,7 +448,8 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
 #' @note spark.mlp since 2.1.0
 setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100,
-                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) {
+                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
             if (is.null(layers)) {
               stop("layers must be a integer vector with length > 1.")
@@ -441,10 +464,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
             if (!is.null(initialWeights)) {
               initialWeights <- as.array(as.numeric(na.omit(initialWeights)))
             }
+            handleInvalid <- match.arg(handleInvalid)
             jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
                                 "fit", data@sdf, formula, as.integer(blockSize), as.array(layers),
                                 as.character(solver), as.integer(maxIter), as.numeric(tol),
-                                as.numeric(stepSize), seed, initialWeights)
+                                as.numeric(stepSize), seed, initialWeights, handleInvalid)
             new("MultilayerPerceptronClassificationModel", jobj = jobj)
           })
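
A usage sketch for `spark.mlp` (hypothetical `traindf` with a single feature column, hence `layers = c(1, 3)`):

```r
# "keep" routes unseen string levels into an extra label bucket at index numLabels
model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3),
                   handleInvalid = "keep")
```
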
@@ -514,6 +538,11 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
 #' @param formula a symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', '.', ':', '+', and '-'.
 #' @param smoothing smoothing parameter.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}.
 #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model.
 #' @rdname spark.naiveBayes
@@ -543,10 +572,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
 #' }
 #' @note spark.naiveBayes since 2.0.0
 setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, smoothing = 1.0) {
+          function(data, formula, smoothing = 1.0,
+                   handleInvalid = c("error", "keep", "skip")) {
             formula <- paste(deparse(formula), collapse = "")
+            handleInvalid <- match.arg(handleInvalid)
             jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
-                                formula, data@sdf, smoothing)
+                                formula, data@sdf, smoothing, handleInvalid)
             new("NaiveBayesModel", jobj = jobj)
           })
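
And for `spark.naiveBayes` (sketch, assumed `traindf` and `testdf` as above):

```r
model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 1.0,
                          handleInvalid = "keep")
predictions <- predict(model, testdf)
```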

R/pkg/R/mllib_tree.R

@@ -164,6 +164,11 @@ print.summary.decisionTree <- function(x) {
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.gbt,SparkDataFrame,formula-method
 #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
@@ -205,7 +210,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, type = c("regression", "classification"),
                    maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
                    seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
-                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -225,6 +231,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
                      new("GBTRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(lossType)) lossType <- "logistic"
                      lossType <- match.arg(lossType, "logistic")
                      jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
@@ -233,7 +240,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
                                          as.numeric(stepSize), as.integer(minInstancesPerNode),
                                          as.numeric(minInfoGain), as.integer(checkpointInterval),
                                          lossType, seed, as.numeric(subsamplingRate),
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("GBTClassificationModel", jobj = jobj)
                    }
           )
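
Note that `match.arg(handleInvalid)` runs only inside the `classification` branch, so the parameter is validated and forwarded for classification and unused for regression. A sketch with an assumed `traindf`:

```r
model <- spark.gbt(traindf, clicked ~ ., type = "classification",
                   handleInvalid = "keep")
```
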
@@ -374,10 +382,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
-#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
-#'                      Supported options: "skip" (filter out rows with invalid data),
-#'                      "error" (throw an error), "keep" (put invalid data in a special additional
-#'                      bucket, at index numLabels). Default is "error".
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -583,6 +592,11 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
+#'                      column of string type in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.decisionTree,SparkDataFrame,formula-method
 #' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
@@ -617,7 +631,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
           function(data, formula, type = c("regression", "classification"),
                    maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
                    minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -636,6 +651,7 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
                      new("DecisionTreeRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(impurity)) impurity <- "gini"
                      impurity <- match.arg(impurity, c("gini", "entropy"))
                      jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
@@ -643,7 +659,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
                                          as.integer(maxBins), impurity,
                                          as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                                          as.integer(checkpointInterval), seed,
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("DecisionTreeClassificationModel", jobj = jobj)
                    }
           )
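
`spark.decisionTree` follows the same classification-only pattern (sketch, assumed `traindf`):

```r
model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
                            maxDepth = 5, maxBins = 16, handleInvalid = "skip")
```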

R/pkg/tests/fulltests/test_mllib_classification.R

@@ -70,6 +70,20 @@ test_that("spark.svmLinear", {
   prediction <- collect(select(predict(model, df), "prediction"))
   expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "list")
 })
 
 test_that("spark.logit", {
@@ -263,6 +277,21 @@ test_that("spark.logit", {
   virginicaCoefs <- summary$coefficients[, "virginica"]
   expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
   expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.logit(traindf, clicked ~ ., regParam = 0.5)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.logit(traindf, clicked ~ ., regParam = 0.5, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
 test_that("spark.mlp", {
@@ -344,6 +373,21 @@ test_that("spark.mlp", {
   expect_equal(summary$numOfOutputs, 3)
   expect_equal(summary$layers, c(4, 3))
   expect_equal(length(summary$weights), 15)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3))
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "list")
 })
 
 test_that("spark.naiveBayes", {
@@ -427,6 +471,20 @@ test_that("spark.naiveBayes", {
   expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
   expect_equal(sum(s$apriori), 1)
   expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
 sparkR.session.stop()

R/pkg/tests/fulltests/test_mllib_tree.R

@@ -109,6 +109,20 @@ test_that("spark.gbt", {
     model <- spark.gbt(data, label ~ features, "classification")
     expect_equal(summary(model)$numFeatures, 692)
   }
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification")
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
 test_that("spark.randomForest", {
@@ -328,6 +342,22 @@ test_that("spark.decisionTree", {
     model <- spark.decisionTree(data, label ~ features, "classification")
     expect_equal(summary(model)$numFeatures, 4)
   }
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 5, maxBins = 16)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 5, maxBins = 16, handleInvalid = "keep")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
 })
 
 sparkR.session.stop()

mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala

@@ -73,11 +73,13 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
       checkpointInterval: Int,
       seed: String,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): DecisionTreeClassifierWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

mllib/src/main/scala/org/apache/spark/ml/r/GBTClassifierWrapper.scala

@@ -78,11 +78,13 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): GBTClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): GBTClassifierWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala

@@ -79,12 +79,14 @@ private[r] object LinearSVCWrapper
       standardization: Boolean,
       threshold: Double,
       weightCol: String,
-      aggregationDepth: Int
+      aggregationDepth: Int,
+      handleInvalid: String
       ): LinearSVCWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala

@@ -103,12 +103,14 @@ private[r] object LogisticRegressionWrapper
       lowerBoundsOnCoefficients: Array[Double],
       upperBoundsOnCoefficients: Array[Double],
       lowerBoundsOnIntercepts: Array[Double],
-      upperBoundsOnIntercepts: Array[Double]
+      upperBoundsOnIntercepts: Array[Double],
+      handleInvalid: String
       ): LogisticRegressionWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)

mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala

@@ -62,7 +62,7 @@ private[r] object MultilayerPerceptronClassifierWrapper
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"
 
-  def fit(
+  def fit( // scalastyle:ignore
       data: DataFrame,
       formula: String,
       blockSize: Int,
@@ -72,11 +72,13 @@ private[r] object MultilayerPerceptronClassifierWrapper
       tol: Double,
       stepSize: Double,
       seed: String,
-      initialWeights: Array[Double]
+      initialWeights: Array[Double],
+      handleInvalid: String
      ): MultilayerPerceptronClassifierWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema

mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala

@@ -57,10 +57,15 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"
 
-  def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
+  def fit(
+      formula: String,
+      data: DataFrame,
+      smoothing: Double,
+      handleInvalid: String): NaiveBayesWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema