[SPARK-21381][SPARKR] SparkR: pass on setHandleInvalid for classification algorithms
## What changes were proposed in this pull request? SPARK-20307 Added handleInvalid option to RFormula for tree-based classification algorithms. We should add this parameter for other classification algorithms in SparkR. This is a followup PR for SPARK-20307. ## How was this patch tested? New Unit tests are added. Author: wangmiao1981 <wm624@hotmail.com> Closes #18605 from wangmiao1981/class.
This commit is contained in:
parent
6b186c9d60
commit
9570e81aa9
|
@ -69,6 +69,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
|
|||
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
|
||||
#' or the number of partitions are large, this param could be adjusted to a larger size.
|
||||
#' This is an expert parameter. Default value should be good for most cases.
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
|
||||
#' column of string type.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @return \code{spark.svmLinear} returns a fitted linear SVM model.
|
||||
#' @rdname spark.svmLinear
|
||||
|
@ -98,7 +103,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
|
|||
#' @note spark.svmLinear since 2.2.0
|
||||
setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
|
||||
function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
|
||||
threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
|
||||
threshold = 0.0, weightCol = NULL, aggregationDepth = 2,
|
||||
handleInvalid = c("error", "keep", "skip")) {
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
|
||||
if (!is.null(weightCol) && weightCol == "") {
|
||||
|
@ -107,10 +113,12 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
|
|||
weightCol <- as.character(weightCol)
|
||||
}
|
||||
|
||||
handleInvalid <- match.arg(handleInvalid)
|
||||
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
|
||||
data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
|
||||
as.numeric(tol), as.logical(standardization), as.numeric(threshold),
|
||||
weightCol, as.integer(aggregationDepth))
|
||||
weightCol, as.integer(aggregationDepth), handleInvalid)
|
||||
new("LinearSVCModel", jobj = jobj)
|
||||
})
|
||||
|
||||
|
@ -218,6 +226,11 @@ function(object, path, overwrite = FALSE) {
|
|||
#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization.
|
||||
#' The bound vector size must be equal to 1 for binomial regression, or the number
|
||||
#' of classes for multinomial regression.
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
|
||||
#' column of string type.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @return \code{spark.logit} returns a fitted logistic regression model.
|
||||
#' @rdname spark.logit
|
||||
|
@ -257,7 +270,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
|
|||
tol = 1E-6, family = "auto", standardization = TRUE,
|
||||
thresholds = 0.5, weightCol = NULL, aggregationDepth = 2,
|
||||
lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL,
|
||||
lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) {
|
||||
lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL,
|
||||
handleInvalid = c("error", "keep", "skip")) {
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
row <- 0
|
||||
col <- 0
|
||||
|
@ -304,6 +318,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
|
|||
upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients))
|
||||
}
|
||||
|
||||
handleInvalid <- match.arg(handleInvalid)
|
||||
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
|
||||
data@sdf, formula, as.numeric(regParam),
|
||||
as.numeric(elasticNetParam), as.integer(maxIter),
|
||||
|
@ -312,7 +328,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
|
|||
weightCol, as.integer(aggregationDepth),
|
||||
as.integer(row), as.integer(col),
|
||||
lowerBoundsOnCoefficients, upperBoundsOnCoefficients,
|
||||
lowerBoundsOnIntercepts, upperBoundsOnIntercepts)
|
||||
lowerBoundsOnIntercepts, upperBoundsOnIntercepts,
|
||||
handleInvalid)
|
||||
new("LogisticRegressionModel", jobj = jobj)
|
||||
})
|
||||
|
||||
|
@ -394,7 +411,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
|
|||
#' @param stepSize stepSize parameter.
|
||||
#' @param seed seed parameter for weights initialization.
|
||||
#' @param initialWeights initialWeights parameter for weights initialization, it should be a
|
||||
#' numeric vector.
|
||||
#' numeric vector.
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
|
||||
#' column of string type.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model.
|
||||
#' @rdname spark.mlp
|
||||
|
@ -426,7 +448,8 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char
|
|||
#' @note spark.mlp since 2.1.0
|
||||
setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
|
||||
function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100,
|
||||
tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) {
|
||||
tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL,
|
||||
handleInvalid = c("error", "keep", "skip")) {
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
if (is.null(layers)) {
|
||||
stop ("layers must be a integer vector with length > 1.")
|
||||
|
@ -441,10 +464,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
|
|||
if (!is.null(initialWeights)) {
|
||||
initialWeights <- as.array(as.numeric(na.omit(initialWeights)))
|
||||
}
|
||||
handleInvalid <- match.arg(handleInvalid)
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
|
||||
"fit", data@sdf, formula, as.integer(blockSize), as.array(layers),
|
||||
as.character(solver), as.integer(maxIter), as.numeric(tol),
|
||||
as.numeric(stepSize), seed, initialWeights)
|
||||
as.numeric(stepSize), seed, initialWeights, handleInvalid)
|
||||
new("MultilayerPerceptronClassificationModel", jobj = jobj)
|
||||
})
|
||||
|
||||
|
@ -514,6 +538,11 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
|
|||
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
|
||||
#' operators are supported, including '~', '.', ':', '+', and '-'.
|
||||
#' @param smoothing smoothing parameter.
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
|
||||
#' column of string type.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}.
|
||||
#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model.
|
||||
#' @rdname spark.naiveBayes
|
||||
|
@ -543,10 +572,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode
|
|||
#' }
|
||||
#' @note spark.naiveBayes since 2.0.0
|
||||
setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
|
||||
function(data, formula, smoothing = 1.0) {
|
||||
function(data, formula, smoothing = 1.0,
|
||||
handleInvalid = c("error", "keep", "skip")) {
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
handleInvalid <- match.arg(handleInvalid)
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
|
||||
formula, data@sdf, smoothing)
|
||||
formula, data@sdf, smoothing, handleInvalid)
|
||||
new("NaiveBayesModel", jobj = jobj)
|
||||
})
|
||||
|
||||
|
|
|
@ -164,6 +164,11 @@ print.summary.decisionTree <- function(x) {
|
|||
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
|
||||
#' can speed up training of deeper trees. Users can set how often should the
|
||||
#' cache be checkpointed or disable it by setting checkpointInterval.
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
|
||||
#' column of string type in classification model.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @aliases spark.gbt,SparkDataFrame,formula-method
|
||||
#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model.
|
||||
|
@ -205,7 +210,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
|
|||
function(data, formula, type = c("regression", "classification"),
|
||||
maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL,
|
||||
seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0,
|
||||
checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) {
|
||||
checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE,
|
||||
handleInvalid = c("error", "keep", "skip")) {
|
||||
type <- match.arg(type)
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
if (!is.null(seed)) {
|
||||
|
@ -225,6 +231,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
|
|||
new("GBTRegressionModel", jobj = jobj)
|
||||
},
|
||||
classification = {
|
||||
handleInvalid <- match.arg(handleInvalid)
|
||||
if (is.null(lossType)) lossType <- "logistic"
|
||||
lossType <- match.arg(lossType, "logistic")
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper",
|
||||
|
@ -233,7 +240,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
|
|||
as.numeric(stepSize), as.integer(minInstancesPerNode),
|
||||
as.numeric(minInfoGain), as.integer(checkpointInterval),
|
||||
lossType, seed, as.numeric(subsamplingRate),
|
||||
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
|
||||
as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
|
||||
handleInvalid)
|
||||
new("GBTClassificationModel", jobj = jobj)
|
||||
}
|
||||
)
|
||||
|
@ -374,10 +382,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
|
|||
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
|
||||
#' can speed up training of deeper trees. Users can set how often should the
|
||||
#' cache be checkpointed or disable it by setting checkpointInterval.
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
|
||||
#' column of string type in classification model.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @aliases spark.randomForest,SparkDataFrame,formula-method
|
||||
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
|
||||
|
@ -583,6 +592,11 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
|
|||
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
|
||||
#' can speed up training of deeper trees. Users can set how often should the
|
||||
#' cache be checkpointed or disable it by setting checkpointInterval.
|
||||
#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label
|
||||
#' column of string type in classification model.
|
||||
#' Supported options: "skip" (filter out rows with invalid data),
|
||||
#' "error" (throw an error), "keep" (put invalid data in a special additional
|
||||
#' bucket, at index numLabels). Default is "error".
|
||||
#' @param ... additional arguments passed to the method.
|
||||
#' @aliases spark.decisionTree,SparkDataFrame,formula-method
|
||||
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
|
||||
|
@ -617,7 +631,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
|
|||
function(data, formula, type = c("regression", "classification"),
|
||||
maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
|
||||
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
|
||||
maxMemoryInMB = 256, cacheNodeIds = FALSE) {
|
||||
maxMemoryInMB = 256, cacheNodeIds = FALSE,
|
||||
handleInvalid = c("error", "keep", "skip")) {
|
||||
type <- match.arg(type)
|
||||
formula <- paste(deparse(formula), collapse = "")
|
||||
if (!is.null(seed)) {
|
||||
|
@ -636,6 +651,7 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
|
|||
new("DecisionTreeRegressionModel", jobj = jobj)
|
||||
},
|
||||
classification = {
|
||||
handleInvalid <- match.arg(handleInvalid)
|
||||
if (is.null(impurity)) impurity <- "gini"
|
||||
impurity <- match.arg(impurity, c("gini", "entropy"))
|
||||
jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
|
||||
|
@ -643,7 +659,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo
|
|||
as.integer(maxBins), impurity,
|
||||
as.integer(minInstancesPerNode), as.numeric(minInfoGain),
|
||||
as.integer(checkpointInterval), seed,
|
||||
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
|
||||
as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
|
||||
handleInvalid)
|
||||
new("DecisionTreeClassificationModel", jobj = jobj)
|
||||
}
|
||||
)
|
||||
|
|
|
@ -70,6 +70,20 @@ test_that("spark.svmLinear", {
|
|||
prediction <- collect(select(predict(model, df), "prediction"))
|
||||
expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
|
||||
|
||||
# Test unseen labels
|
||||
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
|
||||
someString = base::sample(c("this", "that"), 10, replace = TRUE),
|
||||
stringsAsFactors = FALSE)
|
||||
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
|
||||
traindf <- as.DataFrame(data[trainidxs, ])
|
||||
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
|
||||
model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1)
|
||||
predictions <- predict(model, testdf)
|
||||
expect_error(collect(predictions))
|
||||
model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip")
|
||||
predictions <- predict(model, testdf)
|
||||
expect_equal(class(collect(predictions)$clicked[1]), "list")
|
||||
|
||||
})
|
||||
|
||||
test_that("spark.logit", {
|
||||
|
@ -263,6 +277,21 @@ test_that("spark.logit", {
|
|||
virginicaCoefs <- summary$coefficients[, "virginica"]
|
||||
expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1))
|
||||
expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1))
|
||||
|
||||
# Test unseen labels
|
||||
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
|
||||
someString = base::sample(c("this", "that"), 10, replace = TRUE),
|
||||
stringsAsFactors = FALSE)
|
||||
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
|
||||
traindf <- as.DataFrame(data[trainidxs, ])
|
||||
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
|
||||
model <- spark.logit(traindf, clicked ~ ., regParam = 0.5)
|
||||
predictions <- predict(model, testdf)
|
||||
expect_error(collect(predictions))
|
||||
model <- spark.logit(traindf, clicked ~ ., regParam = 0.5, handleInvalid = "keep")
|
||||
predictions <- predict(model, testdf)
|
||||
expect_equal(class(collect(predictions)$clicked[1]), "character")
|
||||
|
||||
})
|
||||
|
||||
test_that("spark.mlp", {
|
||||
|
@ -344,6 +373,21 @@ test_that("spark.mlp", {
|
|||
expect_equal(summary$numOfOutputs, 3)
|
||||
expect_equal(summary$layers, c(4, 3))
|
||||
expect_equal(length(summary$weights), 15)
|
||||
|
||||
# Test unseen labels
|
||||
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
|
||||
someString = base::sample(c("this", "that"), 10, replace = TRUE),
|
||||
stringsAsFactors = FALSE)
|
||||
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
|
||||
traindf <- as.DataFrame(data[trainidxs, ])
|
||||
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
|
||||
model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3))
|
||||
predictions <- predict(model, testdf)
|
||||
expect_error(collect(predictions))
|
||||
model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip")
|
||||
predictions <- predict(model, testdf)
|
||||
expect_equal(class(collect(predictions)$clicked[1]), "list")
|
||||
|
||||
})
|
||||
|
||||
test_that("spark.naiveBayes", {
|
||||
|
@ -427,6 +471,20 @@ test_that("spark.naiveBayes", {
|
|||
expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
|
||||
expect_equal(sum(s$apriori), 1)
|
||||
expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
|
||||
|
||||
# Test unseen labels
|
||||
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
|
||||
someString = base::sample(c("this", "that"), 10, replace = TRUE),
|
||||
stringsAsFactors = FALSE)
|
||||
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
|
||||
traindf <- as.DataFrame(data[trainidxs, ])
|
||||
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
|
||||
model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0)
|
||||
predictions <- predict(model, testdf)
|
||||
expect_error(collect(predictions))
|
||||
model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0, handleInvalid = "keep")
|
||||
predictions <- predict(model, testdf)
|
||||
expect_equal(class(collect(predictions)$clicked[1]), "character")
|
||||
})
|
||||
|
||||
sparkR.session.stop()
|
||||
|
|
|
@ -109,6 +109,20 @@ test_that("spark.gbt", {
|
|||
model <- spark.gbt(data, label ~ features, "classification")
|
||||
expect_equal(summary(model)$numFeatures, 692)
|
||||
}
|
||||
|
||||
# Test unseen labels
|
||||
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
|
||||
someString = base::sample(c("this", "that"), 10, replace = TRUE),
|
||||
stringsAsFactors = FALSE)
|
||||
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
|
||||
traindf <- as.DataFrame(data[trainidxs, ])
|
||||
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
|
||||
model <- spark.gbt(traindf, clicked ~ ., type = "classification")
|
||||
predictions <- predict(model, testdf)
|
||||
expect_error(collect(predictions))
|
||||
model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
|
||||
predictions <- predict(model, testdf)
|
||||
expect_equal(class(collect(predictions)$clicked[1]), "character")
|
||||
})
|
||||
|
||||
test_that("spark.randomForest", {
|
||||
|
@ -328,6 +342,22 @@ test_that("spark.decisionTree", {
|
|||
model <- spark.decisionTree(data, label ~ features, "classification")
|
||||
expect_equal(summary(model)$numFeatures, 4)
|
||||
}
|
||||
|
||||
# Test unseen labels
|
||||
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
|
||||
someString = base::sample(c("this", "that"), 10, replace = TRUE),
|
||||
stringsAsFactors = FALSE)
|
||||
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
|
||||
traindf <- as.DataFrame(data[trainidxs, ])
|
||||
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
|
||||
model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
|
||||
maxDepth = 5, maxBins = 16)
|
||||
predictions <- predict(model, testdf)
|
||||
expect_error(collect(predictions))
|
||||
model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
|
||||
maxDepth = 5, maxBins = 16, handleInvalid = "keep")
|
||||
predictions <- predict(model, testdf)
|
||||
expect_equal(class(collect(predictions)$clicked[1]), "character")
|
||||
})
|
||||
|
||||
sparkR.session.stop()
|
||||
|
|
|
@ -73,11 +73,13 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
|
|||
checkpointInterval: Int,
|
||||
seed: String,
|
||||
maxMemoryInMB: Int,
|
||||
cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
|
||||
cacheNodeIds: Boolean,
|
||||
handleInvalid: String): DecisionTreeClassifierWrapper = {
|
||||
|
||||
val rFormula = new RFormula()
|
||||
.setFormula(formula)
|
||||
.setForceIndexLabel(true)
|
||||
.setHandleInvalid(handleInvalid)
|
||||
checkDataColumns(rFormula, data)
|
||||
val rFormulaModel = rFormula.fit(data)
|
||||
|
||||
|
|
|
@ -78,11 +78,13 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
|
|||
seed: String,
|
||||
subsamplingRate: Double,
|
||||
maxMemoryInMB: Int,
|
||||
cacheNodeIds: Boolean): GBTClassifierWrapper = {
|
||||
cacheNodeIds: Boolean,
|
||||
handleInvalid: String): GBTClassifierWrapper = {
|
||||
|
||||
val rFormula = new RFormula()
|
||||
.setFormula(formula)
|
||||
.setForceIndexLabel(true)
|
||||
.setHandleInvalid(handleInvalid)
|
||||
checkDataColumns(rFormula, data)
|
||||
val rFormulaModel = rFormula.fit(data)
|
||||
|
||||
|
|
|
@ -79,12 +79,14 @@ private[r] object LinearSVCWrapper
|
|||
standardization: Boolean,
|
||||
threshold: Double,
|
||||
weightCol: String,
|
||||
aggregationDepth: Int
|
||||
aggregationDepth: Int,
|
||||
handleInvalid: String
|
||||
): LinearSVCWrapper = {
|
||||
|
||||
val rFormula = new RFormula()
|
||||
.setFormula(formula)
|
||||
.setForceIndexLabel(true)
|
||||
.setHandleInvalid(handleInvalid)
|
||||
checkDataColumns(rFormula, data)
|
||||
val rFormulaModel = rFormula.fit(data)
|
||||
|
||||
|
|
|
@ -103,12 +103,14 @@ private[r] object LogisticRegressionWrapper
|
|||
lowerBoundsOnCoefficients: Array[Double],
|
||||
upperBoundsOnCoefficients: Array[Double],
|
||||
lowerBoundsOnIntercepts: Array[Double],
|
||||
upperBoundsOnIntercepts: Array[Double]
|
||||
upperBoundsOnIntercepts: Array[Double],
|
||||
handleInvalid: String
|
||||
): LogisticRegressionWrapper = {
|
||||
|
||||
val rFormula = new RFormula()
|
||||
.setFormula(formula)
|
||||
.setForceIndexLabel(true)
|
||||
.setHandleInvalid(handleInvalid)
|
||||
checkDataColumns(rFormula, data)
|
||||
val rFormulaModel = rFormula.fit(data)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ private[r] object MultilayerPerceptronClassifierWrapper
|
|||
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
|
||||
val PREDICTED_LABEL_COL = "prediction"
|
||||
|
||||
def fit(
|
||||
def fit( // scalastyle:ignore
|
||||
data: DataFrame,
|
||||
formula: String,
|
||||
blockSize: Int,
|
||||
|
@ -72,11 +72,13 @@ private[r] object MultilayerPerceptronClassifierWrapper
|
|||
tol: Double,
|
||||
stepSize: Double,
|
||||
seed: String,
|
||||
initialWeights: Array[Double]
|
||||
initialWeights: Array[Double],
|
||||
handleInvalid: String
|
||||
): MultilayerPerceptronClassifierWrapper = {
|
||||
val rFormula = new RFormula()
|
||||
.setFormula(formula)
|
||||
.setForceIndexLabel(true)
|
||||
.setHandleInvalid(handleInvalid)
|
||||
checkDataColumns(rFormula, data)
|
||||
val rFormulaModel = rFormula.fit(data)
|
||||
// get labels and feature names from output schema
|
||||
|
|
|
@ -57,10 +57,15 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
|
|||
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
|
||||
val PREDICTED_LABEL_COL = "prediction"
|
||||
|
||||
def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
|
||||
def fit(
|
||||
formula: String,
|
||||
data: DataFrame,
|
||||
smoothing: Double,
|
||||
handleInvalid: String): NaiveBayesWrapper = {
|
||||
val rFormula = new RFormula()
|
||||
.setFormula(formula)
|
||||
.setForceIndexLabel(true)
|
||||
.setHandleInvalid(handleInvalid)
|
||||
checkDataColumns(rFormula, data)
|
||||
val rFormulaModel = rFormula.fit(data)
|
||||
// get labels and feature names from output schema
|
||||
|
|
Loading…
Reference in a new issue