From 7d65a0db4a231882200513836f2720f59b35f364 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 17 Jun 2016 16:07:33 -0700 Subject: [PATCH] [SPARK-16005][R] Add `randomSplit` to SparkR ## What changes were proposed in this pull request? This PR adds `randomSplit` to SparkR for API parity. ## How was this patch tested? Pass the Jenkins tests (with new testcase.) Author: Dongjoon Hyun Closes #13721 from dongjoon-hyun/SPARK-16005. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 37 +++++++++++++++++++++++ R/pkg/R/generics.R | 4 +++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 18 +++++++++++ 4 files changed, 60 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5db43ae649..9412ec3f9e 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -81,6 +81,7 @@ exportMethods("arrange", "orderBy", "persist", "printSchema", + "randomSplit", "rbind", "registerTempTable", "rename", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 231e4f0f4e..4e044565f4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2934,3 +2934,40 @@ setMethod("write.jdbc", write <- callJMethod(write, "mode", jmode) invisible(callJMethod(write, "jdbc", url, tableName, jprops)) }) + +#' randomSplit +#' +#' Return a list of randomly split dataframes with the provided weights. +#' +#' @param x A SparkDataFrame +#' @param weights A vector of weights for splits, will be normalized if they don't sum to 1 +#' @param seed A seed to use for random split +#' +#' @family SparkDataFrame functions +#' @rdname randomSplit +#' @name randomSplit +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- createDataFrame(data.frame(id = 1:1000)) +#' df_list <- randomSplit(df, c(2, 3, 5), 0) +#' # df_list contains 3 SparkDataFrames with each having about 200, 300 and 500 rows respectively +#' sapply(df_list, count) +#' } +#' @note since 2.0.0 +setMethod("randomSplit", + signature(x = "SparkDataFrame", weights = "numeric"), + function(x, weights, seed) { + if (!all(sapply(weights, function(c) { c >= 0 }))) { + stop("all weight values should not be negative") + } + normalized_list <- as.list(weights / sum(weights)) + if (!missing(seed)) { + sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list, as.integer(seed)) + } else { + sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list) + } + sapply(sdfs, dataFrame) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 594bf2eadc..6e754afab6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -679,6 +679,10 @@ setGeneric("withColumnRenamed", #' @export setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) +#' @rdname randomSplit +#' @export +setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) + ###################### Column Methods ########################## #' @rdname column diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7aa03a9048..607bd9c12f 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2280,6 +2280,24 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { expect_equal(collect(before), collect(after)) }) +test_that("randomSplit", { + num <- 4000 + df <- createDataFrame(data.frame(id = 1:num)) + + weights <- c(2, 3, 5) + df_list <- randomSplit(df, weights) + expect_equal(length(weights), length(df_list)) + counts <- sapply(df_list, count) + expect_equal(num, sum(counts)) + expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) + + df_list <- randomSplit(df, weights, 0) + expect_equal(length(weights), length(df_list)) + counts <- sapply(df_list, count) + expect_equal(num, sum(counts)) + expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa)