[SPARK-16785] R dapply doesn't return array or raw columns
## What changes were proposed in this pull request?

Fixed a bug in `dapplyCollect` by changing the `compute` function of `worker.R` to explicitly handle raw (binary) vectors.

cc shivaram

## How was this patch tested?

Unit tests

Author: Clark Fitzgerald <clarkfitzg@gmail.com>

Closes #14783 from clarkfitzg/SPARK-16785.
This commit is contained in:
parent
eb1ab88a86
commit
9fccde4ff8
|
@ -202,7 +202,10 @@ getDefaultSqlSource <- function() {
|
|||
# TODO(davies): support sampling and infer type from NA
|
||||
createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
|
||||
sparkSession <- getSparkSession()
|
||||
|
||||
if (is.data.frame(data)) {
|
||||
# Convert data into a list of rows. Each row is a list.
|
||||
|
||||
# get the names of columns, they will be put into RDD
|
||||
if (is.null(schema)) {
|
||||
schema <- names(data)
|
||||
|
@ -227,6 +230,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
|
|||
args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
|
||||
data <- do.call(mapply, append(args, data))
|
||||
}
|
||||
|
||||
if (is.list(data)) {
|
||||
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
|
||||
rdd <- parallelize(sc, data)
|
||||
|
|
|
@ -697,3 +697,18 @@ isMasterLocal <- function(master) {
|
|||
# Report whether the R_PROFILE_USER environment variable points at a file
# whose name ends in "shell.R".
#
# @return TRUE when the profile path ends in "shell.R", FALSE otherwise
isSparkRShell <- function() {
  profilePath <- Sys.getenv("R_PROFILE_USER")
  grepl(".*shell\\.R$", profilePath, perl = TRUE)
}
|
||||
|
||||
# rbind a list of rows with raw (binary) columns
#
# Raw vectors cannot be stacked into a single atomic column, so every column
# that holds a raw value is kept as a list column (one raw vector per row).
# All remaining columns are flattened into atomic vectors with unlist().
#
# @param inputData a non-empty list of rows, with each row a list. The column
#   types are detected from the first row only.
# @return data.frame with raw columns as lists
rbindRaws <- function(inputData) {
  row1 <- inputData[[1]]
  # is.raw() always yields exactly one logical per column. Comparing class()
  # against "raw" breaks as soon as a value carries more than one class
  # (e.g. POSIXct), because sapply() then returns a list.
  rawcolumns <- vapply(row1, is.raw, logical(1))

  # rbind on lists produces a list-matrix with one row per input row
  listmatrix <- do.call(rbind, inputData)
  # A dataframe with all list columns
  out <- as.data.frame(listmatrix)
  out[!rawcolumns] <- lapply(out[!rawcolumns], unlist)
  out
}
|
||||
|
|
|
@ -2270,6 +2270,27 @@ test_that("dapply() and dapplyCollect() on a DataFrame", {
|
|||
expect_identical(expected, result)
|
||||
})
|
||||
|
||||
# Round-trips a local data.frame holding a list column of raw vectors through
# createDataFrame(), collect() and dapplyCollect() and expects it unchanged.
test_that("dapplyCollect() on DataFrame with a binary column", {
  # Local frame: an integer key plus the serialized bytes of each key
  df <- data.frame(key = 1:3)
  df$bytes <- lapply(df$key, serialize, connection = NULL)

  df_spark <- createDataFrame(df)

  # Plain collect
  collected <- collect(df_spark)
  expect_identical(df, collected)

  # Identity function applied through dapplyCollect
  applied <- dapplyCollect(df_spark, function(x) x)
  expect_identical(df, applied)

  # A data.frame with a single column of bytes
  bytesOnly <- subset(df, select = "bytes")
  bytesOnlySpark <- createDataFrame(bytesOnly)
  bytesOnlyApplied <- dapplyCollect(bytesOnlySpark, function(x) x)
  expect_identical(bytesOnly, bytesOnlyApplied)
})
|
||||
|
||||
test_that("repartition by columns on DataFrame", {
|
||||
df <- createDataFrame(
|
||||
list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)),
|
||||
|
|
|
@ -183,4 +183,28 @@ test_that("overrideEnvs", {
|
|||
expect_equal(config[["config_only"]], "ok")
|
||||
})
|
||||
|
||||
# Checks rbindRaws() on rows that mix raw and non-raw values, and on rows
# that consist of a single raw value.
test_that("rbindRaws", {

  # Raw payloads used as cell values
  shared <- serialize(1:5, connection = NULL)
  rawA <- serialize(1, connection = NULL)
  rawB <- serialize(letters, connection = NULL)
  rawC <- serialize(1:10, connection = NULL)

  # Mixed Column types
  rows <- list(
    list(1L, rawA, "a", shared),
    list(2L, rawB, "b", shared),
    list(3L, rawC, "c", shared)
  )
  expected <- data.frame(V1 = 1:3)
  expected$V2 <- list(rawA, rawB, rawC)
  expected$V3 <- c("a", "b", "c")
  expected$V4 <- list(shared, shared, shared)
  expect_equal(expected, rbindRaws(rows))

  # Single binary column
  singleRows <- list(list(rawA), list(rawB), list(rawC))
  expectedSingle <- subset(expected, select = "V2")
  actualSingle <- setNames(rbindRaws(singleRows), "V2")
  expect_equal(expectedSingle, actualSingle)

})
|
||||
|
||||
sparkR.session.stop()
|
||||
|
|
|
@ -36,7 +36,14 @@ compute <- function(mode, partition, serializer, deserializer, key,
|
|||
# available since R 3.2.4. So we set the global option here.
|
||||
oldOpt <- getOption("stringsAsFactors")
|
||||
options(stringsAsFactors = FALSE)
|
||||
inputData <- do.call(rbind.data.frame, inputData)
|
||||
|
||||
# Handle binary data types
|
||||
if ("raw" %in% sapply(inputData[[1]], class)) {
|
||||
inputData <- SparkR:::rbindRaws(inputData)
|
||||
} else {
|
||||
inputData <- do.call(rbind.data.frame, inputData)
|
||||
}
|
||||
|
||||
options(stringsAsFactors = oldOpt)
|
||||
|
||||
names(inputData) <- colNames
|
||||
|
|
Loading…
Reference in a new issue