Yanbo Liang 75e05a5a96 [SPARK-12566][SPARK-14324][ML] GLM model family, link function support in SparkR:::glm
* SparkR glm supports families and link functions, using the same signature as R's `family` objects.
* Refactor the SparkR glm API. The new API is modeled on R's glm, so it exposes only arguments that R's glm also supports: `formula`, `family`, `data`, `epsilon` and `maxit`.
* This PR focuses on glm() and predict(); summary statistics will be added in a separate PR after this one is merged.
* This PR depends on #12287, which makes GLMs support link prediction on the Scala side. After that is merged, I will add more tests for predict() to this PR.

Tested with unit tests.

Author: Yanbo Liang <>

Closes #12294 from yanboliang/spark-12566.
2016-04-12 10:51:09 -07:00

# Tests for MLlib functions in SparkR
context("MLlib functions")

# Spark and SQL contexts; the tests below build their DataFrames via sqlContext.
sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)
# Compares SparkR glm() predictions with stats::glm on iris for several formula
# shapes: dot with a removed term and no intercept, a feature interaction, and
# a formula built from very long column names. Predictions must agree to 1e-6.
test_that("formula of glm", {
  training <- suppressWarnings(createDataFrame(sqlContext, iris))

  # dot minus and intercept vs native glm
  model <- glm(Sepal_Width ~ . - Species + 0, data = training)
  vals <- collect(select(predict(model, training), "prediction"))
  rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
  # The per-row differences are passed as the failure `info` for debugging.
  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)

  # feature interaction vs native glm
  model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
  vals <- collect(select(predict(model, training), "prediction"))
  rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)

  # glm should work with long formula
  training <- suppressWarnings(createDataFrame(sqlContext, iris))
  training$LongLongLongLongLongName <- training$Sepal_Width
  training$VeryLongLongLongLonLongName <- training$Sepal_Length
  training$AnotherLongLongLongLongName <- training$Species
  model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName,
               data = training)
  vals <- collect(select(predict(model, training), "prediction"))
  # The renamed columns are copies of these iris columns, so the native fit matches.
  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
# Checks SparkR glm()/predict() for the gaussian (no family argument) and
# poisson(link = identity) families: predictions are doubles and agree with
# stats::glm to 1e-6. Also checks that stats::predict still works on an lm fit.
test_that("glm and predict", {
  training <- suppressWarnings(createDataFrame(sqlContext, iris))

  # gaussian family
  model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
  prediction <- predict(model, training)
  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
  vals <- collect(select(prediction, "prediction"))
  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)

  # poisson family
  model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training,
               family = poisson(link = identity))
  prediction <- predict(model, training)
  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
  vals <- collect(select(prediction, "prediction"))
  # Warnings from the native poisson fit are suppressed; only predictions are compared.
  rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
                                        data = iris, family = poisson(link = identity)), iris))
  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)

  # Test stats::predict is working
  x <- rnorm(15)
  y <- x + rnorm(15)
  expect_equal(length(predict(lm(y ~ x))), 15)
})
# Checks SparkR kmeans(): integer predictions, and the cluster labels {0, 1}
# exposed by fitted() and summary(). Also checks that stats::kmeans still works
# on a local data.frame, with labels {1, 2}.
test_that("kmeans", {
  newIris <- iris
  newIris$Species <- NULL
  training <- suppressWarnings(createDataFrame(sqlContext, newIris))

  # Cache the DataFrame here to work around the bug SPARK-13178.
  take(training, 1)

  model <- kmeans(x = training, centers = 2)
  sample <- take(select(predict(model, training), "prediction"), 1)
  expect_equal(typeof(sample$prediction), "integer")
  expect_equal(sample$prediction, 1)

  # Test stats::kmeans is working
  statsModel <- kmeans(x = newIris, centers = 2)
  expect_equal(sort(unique(statsModel$cluster)), c(1, 2))

  # Test fitted works on KMeans
  # (camelCase locals avoid names like `summary.model`, which look like S3 methods)
  fittedModel <- fitted(model)
  expect_equal(sort(collect(distinct(select(fittedModel, "prediction")))$prediction), c(0, 1))

  # Test summary works on KMeans
  summaryModel <- summary(model)
  cluster <- summaryModel$cluster
  expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
})
# Checks SparkR naiveBayes() on the Titanic data against reference output from
# e1071::naiveBayes: prior and one conditional probability, plus per-row
# predictions. Also checks that e1071::naiveBayes (when installed) still works.
test_that("naiveBayes", {
  # R code to reproduce the result.
  # We do not support instance weights yet. So we ignore the frequencies.
  #
  #' library(e1071)
  #' t <- as.data.frame(Titanic)
  #' t1 <- t[t$Freq > 0, -5]
  #' m <- naiveBayes(Survived ~ ., data = t1)
  #' m
  #' predict(m, t1)
  #
  # -- output of 'm'
  # A-priori probabilities:
  # Y
  #        No       Yes
  # 0.4166667 0.5833333
  #
  # Conditional probabilities:
  #      Class
  # Y           1st       2nd       3rd      Crew
  #   No  0.2000000 0.2000000 0.4000000 0.2000000
  #   Yes 0.2857143 0.2857143 0.2857143 0.1428571
  #
  #      Sex
  # Y     Male Female
  #   No   0.5    0.5
  #   Yes  0.5    0.5
  #
  #      Age
  # Y         Child     Adult
  #   No  0.2000000 0.8000000
  #   Yes 0.4285714 0.5714286
  #
  # -- output of 'predict(m, t1)'
  # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No

  # Keep only combinations that occurred (Freq > 0) and drop the Freq column
  # (column 5), so every remaining row is weighted equally.
  t <- as.data.frame(Titanic)
  t1 <- t[t$Freq > 0, -5]
  df <- suppressWarnings(createDataFrame(sqlContext, t1))
  m <- naiveBayes(Survived ~ ., data = df)
  s <- summary(m)
  expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
  expect_equal(sum(s$apriori), 1)
  expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
  p <- collect(select(predict(m, df), "prediction"))
  expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
                               "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
                               "Yes", "Yes", "No", "No"))

  # Test e1071::naiveBayes
  if (requireNamespace("e1071", quietly = TRUE)) {
    expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
    expect_equal(as.character(predict(m, t1[1, ])), "Yes")
  }
})
# Checks SparkR survreg() against reference output from survival::survreg:
# coefficient values and names, plus per-row predictions. Also checks that
# survival::survreg (when installed) still works on a local list.
test_that("survreg", {
  # R code to reproduce the result.
  #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
  #'               x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
  #' library(survival)
  #' model <- survreg(Surv(time, status) ~ x + sex, rData)
  #' summary(model)
  #' predict(model, rData)
  #
  # -- output of 'summary(model)'
  #              Value Std. Error     z        p
  # (Intercept)  1.315      0.270  4.88 1.07e-06
  # x           -0.190      0.173 -1.10 2.72e-01
  # sex         -0.253      0.329 -0.77 4.42e-01
  # Log(scale)  -1.160      0.396 -2.93 3.41e-03
  #
  # -- output of 'predict(model, rData)'
  #        1        2        3        4        5        6        7
  # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269

  # Same observations as rData above, one (time, status, x, sex) record per row.
  data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
               list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
  df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
  model <- survreg(Surv(time, status) ~ x + sex, df)
  stats <- summary(model)
  coefs <- as.vector(stats$coefficients[, 1])
  rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)
  expect_equal(coefs, rCoefs, tolerance = 1e-4)
  # expect_equal rather than all(==): all() over a NULL comparison is TRUE, so
  # missing row names would otherwise pass silently.
  expect_equal(rownames(stats$coefficients),
               c("(Intercept)", "x", "sex", "Log(scale)"))
  p <- collect(select(predict(model, df), "prediction"))
  expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
                               2.390146, 2.891269, 2.891269), tolerance = 1e-4)

  # Test survival::survreg
  if (requireNamespace("survival", quietly = TRUE)) {
    rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
                  x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
    expect_that(
      model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
      not(throws_error()))
    expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
  }
})