[SPARK-16005][R] Add randomSplit
to SparkR
## What changes were proposed in this pull request? This PR adds `randomSplit` to SparkR for API parity. ## How was this patch tested? Pass the Jenkins tests (with new testcase.) Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13721 from dongjoon-hyun/SPARK-16005.
This commit is contained in:
parent
ef3cc4fc09
commit
7d65a0db4a
|
@ -81,6 +81,7 @@ exportMethods("arrange",
|
|||
"orderBy",
|
||||
"persist",
|
||||
"printSchema",
|
||||
"randomSplit",
|
||||
"rbind",
|
||||
"registerTempTable",
|
||||
"rename",
|
||||
|
|
|
@ -2934,3 +2934,40 @@ setMethod("write.jdbc",
|
|||
write <- callJMethod(write, "mode", jmode)
|
||||
invisible(callJMethod(write, "jdbc", url, tableName, jprops))
|
||||
})
|
||||
|
||||
#' randomSplit
|
||||
#'
|
||||
#' Return a list of randomly split dataframes with the provided weights.
|
||||
#'
|
||||
#' @param x A SparkDataFrame
|
||||
#' @param weights A vector of weights for splits, will be normalized if they don't sum to 1
|
||||
#' @param seed A seed to use for random split
|
||||
#'
|
||||
#' @family SparkDataFrame functions
|
||||
#' @rdname randomSplit
|
||||
#' @name randomSplit
|
||||
#' @export
|
||||
#' @examples
|
||||
#'\dontrun{
|
||||
#' sc <- sparkR.init()
|
||||
#' sqlContext <- sparkRSQL.init(sc)
|
||||
#' df <- createDataFrame(data.frame(id = 1:1000))
|
||||
#' df_list <- randomSplit(df, c(2, 3, 5), 0)
|
||||
#' # df_list contains 3 SparkDataFrames with each having about 200, 300 and 500 rows respectively
|
||||
#' sapply(df_list, count)
|
||||
#' }
|
||||
#' @note since 2.0.0
|
||||
setMethod("randomSplit",
|
||||
signature(x = "SparkDataFrame", weights = "numeric"),
|
||||
function(x, weights, seed) {
|
||||
if (!all(sapply(weights, function(c) { c >= 0 }))) {
|
||||
stop("all weight values should not be negative")
|
||||
}
|
||||
normalized_list <- as.list(weights / sum(weights))
|
||||
if (!missing(seed)) {
|
||||
sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list, as.integer(seed))
|
||||
} else {
|
||||
sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list)
|
||||
}
|
||||
sapply(sdfs, dataFrame)
|
||||
})
|
||||
|
|
|
@ -679,6 +679,10 @@ setGeneric("withColumnRenamed",
|
|||
#' @export
|
||||
setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
|
||||
|
||||
#' @rdname randomSplit
|
||||
#' @export
|
||||
setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") })
|
||||
|
||||
###################### Column Methods ##########################
|
||||
|
||||
#' @rdname column
|
||||
|
|
|
@ -2280,6 +2280,24 @@ test_that("createDataFrame sqlContext parameter backward compatibility", {
|
|||
expect_equal(collect(before), collect(after))
|
||||
})
|
||||
|
||||
test_that("randomSplit", {
|
||||
num <- 4000
|
||||
df <- createDataFrame(data.frame(id = 1:num))
|
||||
|
||||
weights <- c(2, 3, 5)
|
||||
df_list <- randomSplit(df, weights)
|
||||
expect_equal(length(weights), length(df_list))
|
||||
counts <- sapply(df_list, count)
|
||||
expect_equal(num, sum(counts))
|
||||
expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 })))
|
||||
|
||||
df_list <- randomSplit(df, weights, 0)
|
||||
expect_equal(length(weights), length(df_list))
|
||||
counts <- sapply(df_list, count)
|
||||
expect_equal(num, sum(counts))
|
||||
expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 })))
|
||||
})
|
||||
|
||||
unlink(parquetPath)
|
||||
unlink(jsonPath)
|
||||
unlink(jsonPathNa)
|
||||
|
|
Loading…
Reference in a new issue