[SPARK-7231] [SPARKR] Changes to make SparkR DataFrame dplyr friendly.

Changes include
1. Rename sortDF to arrange
2. Add new aliases `group_by` and `sample_frac`, `summarize`
3. Add more user friendly column addition (mutate), rename
4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr

Using these changes we can pretty much run the examples as described in http://cran.rstudio.com/web/packages/dplyr/vignettes/introduction.html with the same syntax

The only thing missing in SparkR is auto-resolving column names when used in an expression, i.e. making something like `select(flights, delay)` work as it does in dplyr; right now we need `select(flights, flights$delay)` or `select(flights, "delay")`. But this is a complicated change and I'll file a new issue for it

cc sun-rui rxin

Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu>

Closes #6005 from shivaram/sparkr-df-api and squashes the following commits:

5e0716a [Shivaram Venkataraman] Fix some roxygen bugs
1254953 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into sparkr-df-api
0521149 [Shivaram Venkataraman] Changes to make SparkR DataFrame dplyr friendly. Changes include 1. Rename sortDF to arrange 2. Add new aliases `group_by` and `sample_frac`, `summarize` 3. Add more user friendly column addition (mutate), rename 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr
This commit is contained in:
Shivaram Venkataraman 2015-05-08 18:29:57 -07:00
parent b6c797b08c
commit 0a901dd3a1
8 changed files with 256 additions and 36 deletions

View file

@ -9,7 +9,8 @@ export("print.jobj")
exportClasses("DataFrame")
exportMethods("cache",
exportMethods("arrange",
"cache",
"collect",
"columns",
"count",
@ -20,6 +21,7 @@ exportMethods("cache",
"explain",
"filter",
"first",
"group_by",
"groupBy",
"head",
"insertInto",
@ -28,12 +30,15 @@ exportMethods("cache",
"join",
"limit",
"orderBy",
"mutate",
"names",
"persist",
"printSchema",
"registerTempTable",
"rename",
"repartition",
"sampleDF",
"sample_frac",
"saveAsParquetFile",
"saveAsTable",
"saveDF",
@ -42,7 +47,7 @@ exportMethods("cache",
"selectExpr",
"show",
"showDF",
"sortDF",
"summarize",
"take",
"unionAll",
"unpersist",
@ -72,6 +77,8 @@ exportMethods("abs",
"max",
"mean",
"min",
"n",
"n_distinct",
"rlike",
"sqrt",
"startsWith",

View file

@ -480,6 +480,7 @@ setMethod("distinct",
#' @param withReplacement Sampling with replacement or not
#' @param fraction The (rough) sample target fraction
#' @rdname sampleDF
#' @aliases sample_frac
#' @export
#' @examples
#'\dontrun{
@ -501,6 +502,15 @@ setMethod("sampleDF",
dataFrame(sdf)
})
#' dplyr-style alias for sampleDF: returns a randomly sampled subset of the
#' DataFrame at the given (rough) target fraction, with or without
#' replacement. Delegates directly to sampleDF with the same arguments.
#' @rdname sampleDF
#' @aliases sampleDF
setMethod("sample_frac",
signature(x = "DataFrame", withReplacement = "logical",
fraction = "numeric"),
function(x, withReplacement, fraction) {
sampleDF(x, withReplacement, fraction)
})
#' Count
#'
#' Returns the number of rows in a DataFrame
@ -682,7 +692,8 @@ setMethod("toRDD",
#' @param x a DataFrame
#' @return a GroupedData
#' @seealso GroupedData
#' @rdname DataFrame
#' @aliases group_by
#' @rdname groupBy
#' @export
#' @examples
#' \dontrun{
@ -705,12 +716,21 @@ setMethod("groupBy",
groupedData(sgd)
})
#' Agg
#' dplyr-style alias for groupBy: groups the DataFrame by the given
#' column(s) so aggregations can be run with agg()/summarize(). Forwards
#' all arguments unchanged to groupBy and returns its GroupedData.
#' @rdname groupBy
#' @aliases group_by
setMethod("group_by",
signature(x = "DataFrame"),
function(x, ...) {
groupBy(x, ...)
})
#' Summarize data across columns
#'
#' Compute aggregates by specifying a list of columns
#'
#' @param x a DataFrame
#' @rdname DataFrame
#' @aliases summarize
#' @export
setMethod("agg",
signature(x = "DataFrame"),
@ -718,6 +738,14 @@ setMethod("agg",
agg(groupBy(x), ...)
})
#' dplyr-style alias for agg on an ungrouped DataFrame: computes aggregates
#' over the whole frame (agg on a DataFrame is defined above as
#' agg(groupBy(x), ...)). Forwards all arguments unchanged to agg.
#' @rdname DataFrame
#' @aliases agg
setMethod("summarize",
signature(x = "DataFrame"),
function(x, ...) {
agg(x, ...)
})
############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
@ -886,7 +914,7 @@ setMethod("select",
signature(x = "DataFrame", col = "list"),
function(x, col) {
cols <- lapply(col, function(c) {
if (class(c)== "Column") {
if (class(c) == "Column") {
c@jc
} else {
col(c)@jc
@ -946,6 +974,42 @@ setMethod("withColumn",
select(x, x$"*", alias(col, colName))
})
#' Mutate
#'
#' Return a new DataFrame with the specified columns added.
#'
#' @param x A DataFrame
#' @param col a named argument of the form name = col
#' @return A new DataFrame with the new columns added.
#' @rdname withColumn
#' @aliases withColumn
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
#' names(newDF) # Will contain newCol, newCol2
#' }
setMethod("mutate",
          signature(x = "DataFrame"),
          function(x, ...) {
            # Columns to append, optionally named (name = Column expression).
            cols <- list(...)
            stopifnot(length(cols) > 0)
            # Validate EVERY argument is a Column, not just the first one, so a
            # malformed call fails fast here with a clear error instead of
            # failing obscurely inside select() below.
            stopifnot(all(vapply(cols, function(c) is(c, "Column"), logical(1))))
            ns <- names(cols)
            if (!is.null(ns)) {
              # Alias each named column so the new DataFrame column gets the
              # user-supplied name; unnamed ("") entries keep their expression name.
              for (n in ns) {
                if (n != "") {
                  cols[[n]] <- alias(cols[[n]], n)
                }
              }
            }
            # Select all existing columns plus the new ones.
            do.call(select, c(x, x$"*", cols))
          })
#' WithColumnRenamed
#'
#' Rename an existing column in a DataFrame.
@ -977,17 +1041,15 @@ setMethod("withColumnRenamed",
select(x, cols)
})
setClassUnion("characterOrColumn", c("character", "Column"))
#' SortDF
#' Rename
#'
#' Sort a DataFrame by the specified column(s).
#' Rename an existing column in a DataFrame.
#'
#' @param x A DataFrame to be sorted.
#' @param col Either a Column object or character vector indicating the field to sort on
#' @param ... Additional sorting fields
#' @return A DataFrame where all elements are sorted.
#' @rdname sortDF
#' @param x A DataFrame
#' @param newCol A named pair of the form new_column_name = existing_column
#' @return A DataFrame with the column name changed.
#' @rdname withColumnRenamed
#' @aliases withColumnRenamed
#' @export
#' @examples
#'\dontrun{
@ -995,11 +1057,51 @@ setClassUnion("characterOrColumn", c("character", "Column"))
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' sortDF(df, df$col1)
#' sortDF(df, "col1")
#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
#' newDF <- rename(df, col1 = df$newCol1)
#' }
setMethod("sortDF",
setMethod("rename",
          signature(x = "DataFrame"),
          function(x, ...) {
            # Named arguments of the form new_name = existing_column.
            mapping <- list(...)
            stopifnot(length(mapping) > 0)
            stopifnot(class(mapping[[1]]) == "Column")
            targetNames <- names(mapping)
            # String form of each existing column, used to match by name below.
            sourceNames <- lapply(mapping, function(col) {
              callJMethod(col@jc, "toString")
            })
            # Re-select every column, aliasing the ones being renamed and
            # passing the rest through unchanged.
            selected <- lapply(columns(x), function(colName) {
              idx <- match(colName, sourceNames)
              if (is.na(idx)) {
                col(colName)
              } else {
                alias(col(colName), targetNames[[idx]])
              }
            })
            select(x, selected)
          })
setClassUnion("characterOrColumn", c("character", "Column"))
#' Arrange
#'
#' Sort a DataFrame by the specified column(s).
#'
#' @param x A DataFrame to be sorted.
#' @param col Either a Column object or character vector indicating the field to sort on
#' @param ... Additional sorting fields
#' @return A DataFrame where all elements are sorted.
#' @rdname arrange
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' arrange(df, df$col1)
#' arrange(df, "col1")
#' arrange(df, asc(df$col1), desc(abs(df$col2)))
#' }
setMethod("arrange",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col, ...) {
if (class(col) == "character") {
@ -1013,12 +1115,12 @@ setMethod("sortDF",
dataFrame(sdf)
})
#' @rdname sortDF
# orderBy is kept as a SQL-flavored synonym of arrange; it takes a single
# sort field (character or Column) and delegates directly to arrange.
#' @rdname arrange
#' @aliases orderBy,DataFrame,function-method
setMethod("orderBy",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col) {
arrange(x, col)
})
#' Filter
@ -1026,7 +1128,7 @@ setMethod("orderBy",
#' Filter the rows of a DataFrame according to a given condition.
#'
#' @param x A DataFrame to be sorted.
#' @param condition The condition to sort on. This may either be a Column expression
#' @param condition The condition to filter on. This may either be a Column expression
#' or a string containing a SQL statement
#' @return A DataFrame containing only the rows that meet the condition.
#' @rdname filter
@ -1106,6 +1208,7 @@ setMethod("join",
#'
#' Return a new DataFrame containing the union of rows in this DataFrame
#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
#' Note that this does not remove duplicate rows across the two DataFrames.
#'
#' @param x A Spark DataFrame
#' @param y A Spark DataFrame

View file

@ -131,6 +131,8 @@ createMethods()
#' alias
#'
#' Set a new name for a column
#' @rdname column
setMethod("alias",
signature(object = "Column"),
function(object, data) {
@ -141,8 +143,12 @@ setMethod("alias",
}
})
#' substr
#'
#' An expression that returns a substring.
#'
#' @rdname column
#'
#' @param start starting position
#' @param stop ending position
setMethod("substr", signature(x = "Column"),
@ -152,6 +158,9 @@ setMethod("substr", signature(x = "Column"),
})
#' Casts the column to a different data type.
#'
#' @rdname column
#'
#' @examples
#' \dontrun{
#' cast(df$age, "string")
@ -173,8 +182,8 @@ setMethod("cast",
#' Approx Count Distinct
#'
#' Returns the approximate number of distinct items in a group.
#'
#' @rdname column
#' @return the approximate number of distinct items in a group.
setMethod("approxCountDistinct",
signature(x = "Column"),
function(x, rsd = 0.95) {
@ -184,8 +193,8 @@ setMethod("approxCountDistinct",
#' Count Distinct
#'
#' returns the number of distinct items in a group.
#'
#' @rdname column
#' @return the number of distinct items in a group.
setMethod("countDistinct",
signature(x = "Column"),
function(x, ...) {
@ -197,3 +206,18 @@ setMethod("countDistinct",
column(jc)
})
# n_distinct is the dplyr-style name for countDistinct; it forwards all
# arguments unchanged and returns the resulting Column.
#' @rdname column
#' @aliases countDistinct
setMethod("n_distinct",
signature(x = "Column"),
function(x, ...) {
countDistinct(x, ...)
})
# n is the dplyr-style name for count over a column; it delegates
# directly to count and returns the resulting Column.
#' @rdname column
#' @aliases count
setMethod("n",
signature(x = "Column"),
function(x) {
count(x)
})

View file

@ -380,6 +380,14 @@ setGeneric("value", function(bcast) { standardGeneric("value") })
#################### DataFrame Methods ########################
#' @rdname agg
#' @export
setGeneric("agg", function (x, ...) { standardGeneric("agg") })
#' @rdname arrange
#' @export
setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
#' @rdname schema
#' @export
setGeneric("columns", function(x) {standardGeneric("columns") })
@ -404,6 +412,10 @@ setGeneric("except", function(x, y) { standardGeneric("except") })
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
#' @rdname groupBy
#' @export
setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
#' @rdname DataFrame
#' @export
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
@ -424,7 +436,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
#' @export
setGeneric("limit", function(x, num) {standardGeneric("limit") })
#' @rdname sortDF
#' @rdname withColumn
#' @export
setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
#' @rdname arrange
#' @export
setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
@ -432,10 +448,21 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
#' @export
setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
#' @rdname withColumnRenamed
#' @export
setGeneric("rename", function(x, ...) { standardGeneric("rename") })
#' @rdname registerTempTable
#' @export
setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
#' @rdname sampleDF
#' @export
setGeneric("sample_frac",
function(x, withReplacement, fraction, seed) {
standardGeneric("sample_frac")
})
#' @rdname sampleDF
#' @export
setGeneric("sampleDF",
@ -473,9 +500,9 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
#' @export
setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
#' @rdname sortDF
#' @rdname agg
#' @export
setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
setGeneric("summarize", function(x,...) { standardGeneric("summarize") })
# @rdname tojson
# @export
@ -564,6 +591,14 @@ setGeneric("like", function(x, ...) { standardGeneric("like") })
#' @export
setGeneric("lower", function(x) { standardGeneric("lower") })
#' @rdname column
#' @export
setGeneric("n", function(x) { standardGeneric("n") })
#' @rdname column
#' @export
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
#' @rdname column
#' @export
setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })

View file

@ -56,6 +56,7 @@ setMethod("show", "GroupedData",
#'
#' @param x a GroupedData
#' @return a DataFrame
#' @rdname agg
#' @export
#' @examples
#' \dontrun{
@ -83,8 +84,6 @@ setMethod("count",
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
#' }
setGeneric("agg", function (x, ...) { standardGeneric("agg") })
setMethod("agg",
signature(x = "GroupedData"),
function(x, ...) {
@ -112,6 +111,13 @@ setMethod("agg",
dataFrame(sdf)
})
# summarize is the dplyr-style alias for agg on grouped data; it forwards
# its arguments unchanged and returns the aggregated DataFrame.
#' @rdname agg
#' @aliases agg
setMethod("summarize",
signature(x = "GroupedData"),
function(x, ...) {
agg(x, ...)
})
# sum/mean/avg/min/max
methods <- c("sum", "mean", "avg", "min", "max")

View file

@ -428,6 +428,10 @@ test_that("sampleDF on a DataFrame", {
expect_true(inherits(sampled, "DataFrame"))
sampled2 <- sampleDF(df, FALSE, 0.1)
expect_true(count(sampled2) < 3)
# Also test sample_frac
sampled3 <- sample_frac(df, FALSE, 0.1)
expect_true(count(sampled3) < 3)
})
test_that("select operators", {
@ -533,6 +537,7 @@ test_that("column functions", {
c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
c3 <- lower(c) + upper(c) + first(c) + last(c)
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
c5 <- n(c) + n_distinct(c)
})
test_that("string operators", {
@ -557,6 +562,13 @@ test_that("group by", {
expect_true(inherits(df2, "DataFrame"))
expect_true(3 == count(df2))
# Also test group_by, summarize, mean
gd1 <- group_by(df, "name")
expect_true(inherits(gd1, "GroupedData"))
df_summarized <- summarize(gd, mean_age = mean(df$age))
expect_true(inherits(df_summarized, "DataFrame"))
expect_true(3 == count(df_summarized))
df3 <- agg(gd, age = "sum")
expect_true(inherits(df3, "DataFrame"))
expect_true(3 == count(df3))
@ -573,12 +585,12 @@ test_that("group by", {
expect_true(3 == count(max(gd, "age")))
})
test_that("sortDF() and orderBy() on a DataFrame", {
test_that("arrange() and orderBy() on a DataFrame", {
df <- jsonFile(sqlCtx, jsonPath)
sorted <- sortDF(df, df$age)
sorted <- arrange(df, df$age)
expect_true(collect(sorted)[1,2] == "Michael")
sorted2 <- sortDF(df, "name")
sorted2 <- arrange(df, "name")
expect_true(collect(sorted2)[2,"age"] == 19)
sorted3 <- orderBy(df, asc(df$age))
@ -659,17 +671,17 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
writeLines(lines, jsonPath2)
df2 <- loadDF(sqlCtx, jsonPath2, "json")
unioned <- sortDF(unionAll(df, df2), df$age)
unioned <- arrange(unionAll(df, df2), df$age)
expect_true(inherits(unioned, "DataFrame"))
expect_true(count(unioned) == 6)
expect_true(first(unioned)$name == "Michael")
excepted <- sortDF(except(df, df2), desc(df$age))
excepted <- arrange(except(df, df2), desc(df$age))
expect_true(inherits(unioned, "DataFrame"))
expect_true(count(excepted) == 2)
expect_true(first(excepted)$name == "Justin")
intersected <- sortDF(intersect(df, df2), df$age)
intersected <- arrange(intersect(df, df2), df$age)
expect_true(inherits(unioned, "DataFrame"))
expect_true(count(intersected) == 1)
expect_true(first(intersected)$name == "Andy")
@ -687,6 +699,18 @@ test_that("withColumn() and withColumnRenamed()", {
expect_true(columns(newDF2)[1] == "newerAge")
})
test_that("mutate() and rename()", {
df <- jsonFile(sqlCtx, jsonPath)
# mutate() should append a derived column, keeping the two existing ones.
newDF <- mutate(df, newAge = df$age + 2)
expect_true(length(columns(newDF)) == 3)
expect_true(columns(newDF)[3] == "newAge")
expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
# rename() should change a column's name without altering the column count.
newDF2 <- rename(df, newerAge = df$age)
expect_true(length(columns(newDF2)) == 2)
expect_true(columns(newDF2)[1] == "newerAge")
})
test_that("saveDF() on DataFrame and works with parquetFile", {
df <- jsonFile(sqlCtx, jsonPath)
saveDF(df, parquetPath, "parquet", mode="overwrite")

View file

@ -246,6 +246,22 @@ object functions {
*/
def last(columnName: String): Column = last(Column(columnName))
/**
* Aggregate function: returns the average of the values in a group.
* Alias for avg, provided for dplyr/statistics-style naming.
*
* @param e the column to average
* @group agg_funcs
*/
def mean(e: Column): Column = avg(e)
/**
* Aggregate function: returns the average of the values in a group.
* Alias for avg, provided for dplyr/statistics-style naming.
*
* @param columnName name of the column to average
* @group agg_funcs
*/
def mean(columnName: String): Column = avg(columnName)
/**
* Aggregate function: returns the minimum value of the expression in a group.
*

View file

@ -308,6 +308,11 @@ class DataFrameSuite extends QueryTest {
testData2.agg(avg('a)),
Row(2.0))
// Also check mean
checkAnswer(
testData2.agg(mean('a)),
Row(2.0))
checkAnswer(
testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)