2017-01-08 04:10:36 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
|
|
|
library(testthat)

# Suite label reported by testthat for the expectations in this file.
context("MLlib tree-based algorithms")

# Tests for MLlib tree-based algorithms in SparkR

# Start (or reuse) a SparkR session shared by all tests in this file.
# sparkRTestMaster is supplied by the test harness -- TODO confirm where it
# is defined. Hive support is not needed by these tests, so it is disabled.
sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
|
2017-01-08 04:10:36 -05:00
|
|
|
|
|
|
|
# Resolve a path fragment against the Spark installation directory.
#
# x: a path relative to the Spark home (e.g. "data/mllib/...").
# Returns the combined path built from the "spark.home" session conf.
absoluteSparkPath <- function(x) {
  file.path(sparkR.conf("spark.home"), x)
}
|
|
|
|
|
|
|
|
test_that("spark.gbt", {
  # ---- Regression ----
  data <- suppressWarnings(createDataFrame(longley))
  model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123)
  predictions <- collect(predict(model, data))
  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
                                         63.221, 63.639, 64.989, 63.761,
                                         66.019, 67.857, 68.169, 66.513,
                                         68.655, 69.564, 69.331, 70.551),
               tolerance = 1e-4)
  stats <- summary(model)
  expect_equal(stats$numTrees, 20)
  expect_equal(stats$maxDepth, 5)
  expect_equal(stats$formula, "Employed ~ .")
  expect_equal(stats$numFeatures, 6)
  expect_equal(length(stats$treeWeights), 20)

  # Save/load round trip must preserve the regression summary.
  modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  stats2 <- summary(model2)
  expect_equal(stats$formula, stats2$formula)
  expect_equal(stats$numFeatures, stats2$numFeatures)
  expect_equal(stats$features, stats2$features)
  expect_equal(stats$featureImportances, stats2$featureImportances)
  expect_equal(stats$maxDepth, stats2$maxDepth)
  expect_equal(stats$numTrees, stats2$numTrees)
  expect_equal(stats$treeWeights, stats2$treeWeights)

  unlink(modelPath)

  # ---- Classification ----
  # label must be binary - GBTClassifier currently only supports binary classification.
  iris2 <- iris[iris$Species != "virginica", ]
  data <- suppressWarnings(createDataFrame(iris2))
  model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
  stats <- summary(model)
  expect_equal(stats$numFeatures, 2)
  expect_equal(stats$numTrees, 20)
  expect_equal(stats$maxDepth, 5)
  # summary() must print without error and produce a non-trivial report.
  expect_error(capture.output(stats), NA)
  expect_true(length(capture.output(stats)) > 6)
  predictions <- collect(predict(model, data))$prediction
  # test string prediction values
  expect_equal(length(grep("setosa", predictions)), 50)
  expect_equal(length(grep("versicolor", predictions)), 50)

  # Save/load round trip for the classification model.
  modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  stats2 <- summary(model2)
  # Compare fields the GBT summary actually exposes (the same ones asserted
  # above). The previous checks compared stats$depth / stats$numNodes /
  # stats$numClasses, which summary() does not produce for GBT models, so
  # both sides were NULL and the expectations passed vacuously.
  expect_equal(stats$formula, stats2$formula)
  expect_equal(stats$numFeatures, stats2$numFeatures)
  expect_equal(stats$numTrees, stats2$numTrees)
  expect_equal(stats$treeWeights, stats2$treeWeights)

  unlink(modelPath)

  # Numeric (0/1) response variable instead of a string label.
  iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
  df <- suppressWarnings(createDataFrame(iris2))
  m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
  s <- summary(m)
  # test numeric prediction values
  expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
  expect_equal(s$numFeatures, 5)
  expect_equal(s$numTrees, 20)
  # Bug fix: this previously asserted on `stats` (the summary of the earlier
  # string-label model) instead of `s`, so the numeric-label model's maxDepth
  # was never actually checked.
  expect_equal(s$maxDepth, 5)

  # spark.gbt classification can work on libsvm data
  data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
                  source = "libsvm")
  model <- spark.gbt(data, label ~ features, "classification")
  expect_equal(summary(model)$numFeatures, 692)
})
|
|
|
|
|
|
|
|
test_that("spark.randomForest", {
  # ---- Regression ----
  data <- suppressWarnings(createDataFrame(longley))
  # A single tree reproduces the training targets exactly (to tolerance).
  model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                              numTrees = 1)

  predictions <- collect(predict(model, data))
  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
                                         63.221, 63.639, 64.989, 63.761,
                                         66.019, 67.857, 68.169, 66.513,
                                         68.655, 69.564, 69.331, 70.551),
               tolerance = 1e-4)

  stats <- summary(model)
  expect_equal(stats$numTrees, 1)
  expect_equal(stats$maxDepth, 5)
  # summary() must print without error and produce a non-trivial report.
  expect_error(capture.output(stats), NA)
  expect_true(length(capture.output(stats)) > 6)

  # With a 20-tree ensemble the averaged predictions differ from the targets.
  model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                              numTrees = 20, seed = 123)
  predictions <- collect(predict(model, data))
  expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070,
                                         63.53160, 64.05470, 65.12710, 64.30450,
                                         66.70910, 67.86125, 68.08700, 67.21865,
                                         68.89275, 69.53180, 69.39640, 69.68250),
               tolerance = 1e-4)
  stats <- summary(model)
  expect_equal(stats$numTrees, 20)
  expect_equal(stats$maxDepth, 5)

  # Save/load round trip must preserve the regression summary.
  modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  stats2 <- summary(model2)
  expect_equal(stats$formula, stats2$formula)
  expect_equal(stats$numFeatures, stats2$numFeatures)
  expect_equal(stats$features, stats2$features)
  expect_equal(stats$featureImportances, stats2$featureImportances)
  expect_equal(stats$numTrees, stats2$numTrees)
  expect_equal(stats$maxDepth, stats2$maxDepth)
  expect_equal(stats$treeWeights, stats2$treeWeights)

  unlink(modelPath)

  # ---- Classification ----
  data <- suppressWarnings(createDataFrame(iris))
  model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
                              maxDepth = 5, maxBins = 16)

  stats <- summary(model)
  expect_equal(stats$numFeatures, 2)
  expect_equal(stats$numTrees, 20)
  expect_equal(stats$maxDepth, 5)
  expect_error(capture.output(stats), NA)
  expect_true(length(capture.output(stats)) > 6)
  # Test string prediction values
  predictions <- collect(predict(model, data))$prediction
  expect_equal(length(grep("setosa", predictions)), 50)
  expect_equal(length(grep("versicolor", predictions)), 50)

  # Save/load round trip for the classification model.
  modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  stats2 <- summary(model2)
  # Compare fields the random forest summary actually exposes (the same ones
  # asserted above). The previous checks compared stats$depth /
  # stats$numNodes / stats$numClasses, which summary() does not produce for
  # this model, so both sides were NULL and the expectations passed
  # vacuously.
  expect_equal(stats$formula, stats2$formula)
  expect_equal(stats$numFeatures, stats2$numFeatures)
  expect_equal(stats$numTrees, stats2$numTrees)
  expect_equal(stats$treeWeights, stats2$treeWeights)

  unlink(modelPath)

  # Test numeric response variable
  labelToIndex <- function(species) {
    switch(as.character(species),
      setosa = 0.0,
      versicolor = 1.0,
      virginica = 2.0
    )
  }
  # Note: lapply produces a list column; createDataFrame converts it to a
  # numeric label column. iris[-5] drops the original Species factor column.
  iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
  data <- suppressWarnings(createDataFrame(iris[-5]))
  model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
                              maxDepth = 5, maxBins = 16)
  stats <- summary(model)
  expect_equal(stats$numFeatures, 2)
  expect_equal(stats$numTrees, 20)
  expect_equal(stats$maxDepth, 5)

  # Test numeric prediction values
  predictions <- collect(predict(model, data))$prediction
  expect_equal(length(grep("1.0", predictions)), 50)
  expect_equal(length(grep("2.0", predictions)), 50)

  # spark.randomForest classification can work on libsvm data
  data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
                  source = "libsvm")
  model <- spark.randomForest(data, label ~ features, "classification")
  expect_equal(summary(model)$numFeatures, 4)
})
|
|
|
|
|
|
|
|
# Tear down the Spark session started at the top of this file.
sparkR.session.stop()
|