[SPARK-7231] [SPARKR] Changes to make SparkR DataFrame dplyr friendly.
Changes include 1. Rename sortDF to arrange 2. Add new aliases `group_by` and `sample_frac`, `summarize` 3. Add more user friendly column addition (mutate), rename 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr Using these changes we can pretty much run the examples as described in http://cran.rstudio.com/web/packages/dplyr/vignettes/introduction.html with the same syntax The only thing missing in SparkR is auto resolving column names when used in an expression i.e. making something like `select(flights, delay)` works in dply but we right now need `select(flights, flights$delay)` or `select(flights, "delay")`. But this is a complicated change and I'll file a new issue for it cc sun-rui rxin Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu> Closes #6005 from shivaram/sparkr-df-api and squashes the following commits: 5e0716a [Shivaram Venkataraman] Fix some roxygen bugs 1254953 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into sparkr-df-api 0521149 [Shivaram Venkataraman] Changes to make SparkR DataFrame dplyr friendly. Changes include 1. Rename sortDF to arrange 2. Add new aliases `group_by` and `sample_frac`, `summarize` 3. Add more user friendly column addition (mutate), rename 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr
This commit is contained in:
parent
b6c797b08c
commit
0a901dd3a1
|
@ -9,7 +9,8 @@ export("print.jobj")
|
||||||
|
|
||||||
exportClasses("DataFrame")
|
exportClasses("DataFrame")
|
||||||
|
|
||||||
exportMethods("cache",
|
exportMethods("arrange",
|
||||||
|
"cache",
|
||||||
"collect",
|
"collect",
|
||||||
"columns",
|
"columns",
|
||||||
"count",
|
"count",
|
||||||
|
@ -20,6 +21,7 @@ exportMethods("cache",
|
||||||
"explain",
|
"explain",
|
||||||
"filter",
|
"filter",
|
||||||
"first",
|
"first",
|
||||||
|
"group_by",
|
||||||
"groupBy",
|
"groupBy",
|
||||||
"head",
|
"head",
|
||||||
"insertInto",
|
"insertInto",
|
||||||
|
@ -28,12 +30,15 @@ exportMethods("cache",
|
||||||
"join",
|
"join",
|
||||||
"limit",
|
"limit",
|
||||||
"orderBy",
|
"orderBy",
|
||||||
|
"mutate",
|
||||||
"names",
|
"names",
|
||||||
"persist",
|
"persist",
|
||||||
"printSchema",
|
"printSchema",
|
||||||
"registerTempTable",
|
"registerTempTable",
|
||||||
|
"rename",
|
||||||
"repartition",
|
"repartition",
|
||||||
"sampleDF",
|
"sampleDF",
|
||||||
|
"sample_frac",
|
||||||
"saveAsParquetFile",
|
"saveAsParquetFile",
|
||||||
"saveAsTable",
|
"saveAsTable",
|
||||||
"saveDF",
|
"saveDF",
|
||||||
|
@ -42,7 +47,7 @@ exportMethods("cache",
|
||||||
"selectExpr",
|
"selectExpr",
|
||||||
"show",
|
"show",
|
||||||
"showDF",
|
"showDF",
|
||||||
"sortDF",
|
"summarize",
|
||||||
"take",
|
"take",
|
||||||
"unionAll",
|
"unionAll",
|
||||||
"unpersist",
|
"unpersist",
|
||||||
|
@ -72,6 +77,8 @@ exportMethods("abs",
|
||||||
"max",
|
"max",
|
||||||
"mean",
|
"mean",
|
||||||
"min",
|
"min",
|
||||||
|
"n",
|
||||||
|
"n_distinct",
|
||||||
"rlike",
|
"rlike",
|
||||||
"sqrt",
|
"sqrt",
|
||||||
"startsWith",
|
"startsWith",
|
||||||
|
|
|
@ -480,6 +480,7 @@ setMethod("distinct",
|
||||||
#' @param withReplacement Sampling with replacement or not
|
#' @param withReplacement Sampling with replacement or not
|
||||||
#' @param fraction The (rough) sample target fraction
|
#' @param fraction The (rough) sample target fraction
|
||||||
#' @rdname sampleDF
|
#' @rdname sampleDF
|
||||||
|
#' @aliases sample_frac
|
||||||
#' @export
|
#' @export
|
||||||
#' @examples
|
#' @examples
|
||||||
#'\dontrun{
|
#'\dontrun{
|
||||||
|
@ -501,6 +502,15 @@ setMethod("sampleDF",
|
||||||
dataFrame(sdf)
|
dataFrame(sdf)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' @rdname sampleDF
|
||||||
|
#' @aliases sampleDF
|
||||||
|
setMethod("sample_frac",
|
||||||
|
signature(x = "DataFrame", withReplacement = "logical",
|
||||||
|
fraction = "numeric"),
|
||||||
|
function(x, withReplacement, fraction) {
|
||||||
|
sampleDF(x, withReplacement, fraction)
|
||||||
|
})
|
||||||
|
|
||||||
#' Count
|
#' Count
|
||||||
#'
|
#'
|
||||||
#' Returns the number of rows in a DataFrame
|
#' Returns the number of rows in a DataFrame
|
||||||
|
@ -682,7 +692,8 @@ setMethod("toRDD",
|
||||||
#' @param x a DataFrame
|
#' @param x a DataFrame
|
||||||
#' @return a GroupedData
|
#' @return a GroupedData
|
||||||
#' @seealso GroupedData
|
#' @seealso GroupedData
|
||||||
#' @rdname DataFrame
|
#' @aliases group_by
|
||||||
|
#' @rdname groupBy
|
||||||
#' @export
|
#' @export
|
||||||
#' @examples
|
#' @examples
|
||||||
#' \dontrun{
|
#' \dontrun{
|
||||||
|
@ -705,12 +716,21 @@ setMethod("groupBy",
|
||||||
groupedData(sgd)
|
groupedData(sgd)
|
||||||
})
|
})
|
||||||
|
|
||||||
#' Agg
|
#' @rdname groupBy
|
||||||
|
#' @aliases group_by
|
||||||
|
setMethod("group_by",
|
||||||
|
signature(x = "DataFrame"),
|
||||||
|
function(x, ...) {
|
||||||
|
groupBy(x, ...)
|
||||||
|
})
|
||||||
|
|
||||||
|
#' Summarize data across columns
|
||||||
#'
|
#'
|
||||||
#' Compute aggregates by specifying a list of columns
|
#' Compute aggregates by specifying a list of columns
|
||||||
#'
|
#'
|
||||||
#' @param x a DataFrame
|
#' @param x a DataFrame
|
||||||
#' @rdname DataFrame
|
#' @rdname DataFrame
|
||||||
|
#' @aliases summarize
|
||||||
#' @export
|
#' @export
|
||||||
setMethod("agg",
|
setMethod("agg",
|
||||||
signature(x = "DataFrame"),
|
signature(x = "DataFrame"),
|
||||||
|
@ -718,6 +738,14 @@ setMethod("agg",
|
||||||
agg(groupBy(x), ...)
|
agg(groupBy(x), ...)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' @rdname DataFrame
|
||||||
|
#' @aliases agg
|
||||||
|
setMethod("summarize",
|
||||||
|
signature(x = "DataFrame"),
|
||||||
|
function(x, ...) {
|
||||||
|
agg(x, ...)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
############################## RDD Map Functions ##################################
|
############################## RDD Map Functions ##################################
|
||||||
# All of the following functions mirror the existing RDD map functions, #
|
# All of the following functions mirror the existing RDD map functions, #
|
||||||
|
@ -886,7 +914,7 @@ setMethod("select",
|
||||||
signature(x = "DataFrame", col = "list"),
|
signature(x = "DataFrame", col = "list"),
|
||||||
function(x, col) {
|
function(x, col) {
|
||||||
cols <- lapply(col, function(c) {
|
cols <- lapply(col, function(c) {
|
||||||
if (class(c)== "Column") {
|
if (class(c) == "Column") {
|
||||||
c@jc
|
c@jc
|
||||||
} else {
|
} else {
|
||||||
col(c)@jc
|
col(c)@jc
|
||||||
|
@ -946,6 +974,42 @@ setMethod("withColumn",
|
||||||
select(x, x$"*", alias(col, colName))
|
select(x, x$"*", alias(col, colName))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' Mutate
|
||||||
|
#'
|
||||||
|
#' Return a new DataFrame with the specified columns added.
|
||||||
|
#'
|
||||||
|
#' @param x A DataFrame
|
||||||
|
#' @param col a named argument of the form name = col
|
||||||
|
#' @return A new DataFrame with the new columns added.
|
||||||
|
#' @rdname withColumn
|
||||||
|
#' @aliases withColumn
|
||||||
|
#' @export
|
||||||
|
#' @examples
|
||||||
|
#'\dontrun{
|
||||||
|
#' sc <- sparkR.init()
|
||||||
|
#' sqlCtx <- sparkRSQL.init(sc)
|
||||||
|
#' path <- "path/to/file.json"
|
||||||
|
#' df <- jsonFile(sqlCtx, path)
|
||||||
|
#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
|
||||||
|
#' names(newDF) # Will contain newCol, newCol2
|
||||||
|
#' }
|
||||||
|
setMethod("mutate",
|
||||||
|
signature(x = "DataFrame"),
|
||||||
|
function(x, ...) {
|
||||||
|
cols <- list(...)
|
||||||
|
stopifnot(length(cols) > 0)
|
||||||
|
stopifnot(class(cols[[1]]) == "Column")
|
||||||
|
ns <- names(cols)
|
||||||
|
if (!is.null(ns)) {
|
||||||
|
for (n in ns) {
|
||||||
|
if (n != "") {
|
||||||
|
cols[[n]] <- alias(cols[[n]], n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
do.call(select, c(x, x$"*", cols))
|
||||||
|
})
|
||||||
|
|
||||||
#' WithColumnRenamed
|
#' WithColumnRenamed
|
||||||
#'
|
#'
|
||||||
#' Rename an existing column in a DataFrame.
|
#' Rename an existing column in a DataFrame.
|
||||||
|
@ -977,17 +1041,15 @@ setMethod("withColumnRenamed",
|
||||||
select(x, cols)
|
select(x, cols)
|
||||||
})
|
})
|
||||||
|
|
||||||
setClassUnion("characterOrColumn", c("character", "Column"))
|
#' Rename
|
||||||
|
|
||||||
#' SortDF
|
|
||||||
#'
|
#'
|
||||||
#' Sort a DataFrame by the specified column(s).
|
#' Rename an existing column in a DataFrame.
|
||||||
#'
|
#'
|
||||||
#' @param x A DataFrame to be sorted.
|
#' @param x A DataFrame
|
||||||
#' @param col Either a Column object or character vector indicating the field to sort on
|
#' @param newCol A named pair of the form new_column_name = existing_column
|
||||||
#' @param ... Additional sorting fields
|
#' @return A DataFrame with the column name changed.
|
||||||
#' @return A DataFrame where all elements are sorted.
|
#' @rdname withColumnRenamed
|
||||||
#' @rdname sortDF
|
#' @aliases withColumnRenamed
|
||||||
#' @export
|
#' @export
|
||||||
#' @examples
|
#' @examples
|
||||||
#'\dontrun{
|
#'\dontrun{
|
||||||
|
@ -995,11 +1057,51 @@ setClassUnion("characterOrColumn", c("character", "Column"))
|
||||||
#' sqlCtx <- sparkRSQL.init(sc)
|
#' sqlCtx <- sparkRSQL.init(sc)
|
||||||
#' path <- "path/to/file.json"
|
#' path <- "path/to/file.json"
|
||||||
#' df <- jsonFile(sqlCtx, path)
|
#' df <- jsonFile(sqlCtx, path)
|
||||||
#' sortDF(df, df$col1)
|
#' newDF <- rename(df, col1 = df$newCol1)
|
||||||
#' sortDF(df, "col1")
|
|
||||||
#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
|
|
||||||
#' }
|
#' }
|
||||||
setMethod("sortDF",
|
setMethod("rename",
|
||||||
|
signature(x = "DataFrame"),
|
||||||
|
function(x, ...) {
|
||||||
|
renameCols <- list(...)
|
||||||
|
stopifnot(length(renameCols) > 0)
|
||||||
|
stopifnot(class(renameCols[[1]]) == "Column")
|
||||||
|
newNames <- names(renameCols)
|
||||||
|
oldNames <- lapply(renameCols, function(col) {
|
||||||
|
callJMethod(col@jc, "toString")
|
||||||
|
})
|
||||||
|
cols <- lapply(columns(x), function(c) {
|
||||||
|
if (c %in% oldNames) {
|
||||||
|
alias(col(c), newNames[[match(c, oldNames)]])
|
||||||
|
} else {
|
||||||
|
col(c)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
select(x, cols)
|
||||||
|
})
|
||||||
|
|
||||||
|
setClassUnion("characterOrColumn", c("character", "Column"))
|
||||||
|
|
||||||
|
#' Arrange
|
||||||
|
#'
|
||||||
|
#' Sort a DataFrame by the specified column(s).
|
||||||
|
#'
|
||||||
|
#' @param x A DataFrame to be sorted.
|
||||||
|
#' @param col Either a Column object or character vector indicating the field to sort on
|
||||||
|
#' @param ... Additional sorting fields
|
||||||
|
#' @return A DataFrame where all elements are sorted.
|
||||||
|
#' @rdname arrange
|
||||||
|
#' @export
|
||||||
|
#' @examples
|
||||||
|
#'\dontrun{
|
||||||
|
#' sc <- sparkR.init()
|
||||||
|
#' sqlCtx <- sparkRSQL.init(sc)
|
||||||
|
#' path <- "path/to/file.json"
|
||||||
|
#' df <- jsonFile(sqlCtx, path)
|
||||||
|
#' arrange(df, df$col1)
|
||||||
|
#' arrange(df, "col1")
|
||||||
|
#' arrange(df, asc(df$col1), desc(abs(df$col2)))
|
||||||
|
#' }
|
||||||
|
setMethod("arrange",
|
||||||
signature(x = "DataFrame", col = "characterOrColumn"),
|
signature(x = "DataFrame", col = "characterOrColumn"),
|
||||||
function(x, col, ...) {
|
function(x, col, ...) {
|
||||||
if (class(col) == "character") {
|
if (class(col) == "character") {
|
||||||
|
@ -1013,12 +1115,12 @@ setMethod("sortDF",
|
||||||
dataFrame(sdf)
|
dataFrame(sdf)
|
||||||
})
|
})
|
||||||
|
|
||||||
#' @rdname sortDF
|
#' @rdname arrange
|
||||||
#' @aliases orderBy,DataFrame,function-method
|
#' @aliases orderBy,DataFrame,function-method
|
||||||
setMethod("orderBy",
|
setMethod("orderBy",
|
||||||
signature(x = "DataFrame", col = "characterOrColumn"),
|
signature(x = "DataFrame", col = "characterOrColumn"),
|
||||||
function(x, col) {
|
function(x, col) {
|
||||||
sortDF(x, col)
|
arrange(x, col)
|
||||||
})
|
})
|
||||||
|
|
||||||
#' Filter
|
#' Filter
|
||||||
|
@ -1026,7 +1128,7 @@ setMethod("orderBy",
|
||||||
#' Filter the rows of a DataFrame according to a given condition.
|
#' Filter the rows of a DataFrame according to a given condition.
|
||||||
#'
|
#'
|
||||||
#' @param x A DataFrame to be sorted.
|
#' @param x A DataFrame to be sorted.
|
||||||
#' @param condition The condition to sort on. This may either be a Column expression
|
#' @param condition The condition to filter on. This may either be a Column expression
|
||||||
#' or a string containing a SQL statement
|
#' or a string containing a SQL statement
|
||||||
#' @return A DataFrame containing only the rows that meet the condition.
|
#' @return A DataFrame containing only the rows that meet the condition.
|
||||||
#' @rdname filter
|
#' @rdname filter
|
||||||
|
@ -1106,6 +1208,7 @@ setMethod("join",
|
||||||
#'
|
#'
|
||||||
#' Return a new DataFrame containing the union of rows in this DataFrame
|
#' Return a new DataFrame containing the union of rows in this DataFrame
|
||||||
#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
|
#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
|
||||||
|
#' Note that this does not remove duplicate rows across the two DataFrames.
|
||||||
#'
|
#'
|
||||||
#' @param x A Spark DataFrame
|
#' @param x A Spark DataFrame
|
||||||
#' @param y A Spark DataFrame
|
#' @param y A Spark DataFrame
|
||||||
|
|
|
@ -131,6 +131,8 @@ createMethods()
|
||||||
#' alias
|
#' alias
|
||||||
#'
|
#'
|
||||||
#' Set a new name for a column
|
#' Set a new name for a column
|
||||||
|
|
||||||
|
#' @rdname column
|
||||||
setMethod("alias",
|
setMethod("alias",
|
||||||
signature(object = "Column"),
|
signature(object = "Column"),
|
||||||
function(object, data) {
|
function(object, data) {
|
||||||
|
@ -141,8 +143,12 @@ setMethod("alias",
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' substr
|
||||||
|
#'
|
||||||
#' An expression that returns a substring.
|
#' An expression that returns a substring.
|
||||||
#'
|
#'
|
||||||
|
#' @rdname column
|
||||||
|
#'
|
||||||
#' @param start starting position
|
#' @param start starting position
|
||||||
#' @param stop ending position
|
#' @param stop ending position
|
||||||
setMethod("substr", signature(x = "Column"),
|
setMethod("substr", signature(x = "Column"),
|
||||||
|
@ -152,6 +158,9 @@ setMethod("substr", signature(x = "Column"),
|
||||||
})
|
})
|
||||||
|
|
||||||
#' Casts the column to a different data type.
|
#' Casts the column to a different data type.
|
||||||
|
#'
|
||||||
|
#' @rdname column
|
||||||
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
#' \dontrun{
|
#' \dontrun{
|
||||||
#' cast(df$age, "string")
|
#' cast(df$age, "string")
|
||||||
|
@ -173,8 +182,8 @@ setMethod("cast",
|
||||||
|
|
||||||
#' Approx Count Distinct
|
#' Approx Count Distinct
|
||||||
#'
|
#'
|
||||||
#' Returns the approximate number of distinct items in a group.
|
#' @rdname column
|
||||||
#'
|
#' @return the approximate number of distinct items in a group.
|
||||||
setMethod("approxCountDistinct",
|
setMethod("approxCountDistinct",
|
||||||
signature(x = "Column"),
|
signature(x = "Column"),
|
||||||
function(x, rsd = 0.95) {
|
function(x, rsd = 0.95) {
|
||||||
|
@ -184,8 +193,8 @@ setMethod("approxCountDistinct",
|
||||||
|
|
||||||
#' Count Distinct
|
#' Count Distinct
|
||||||
#'
|
#'
|
||||||
#' returns the number of distinct items in a group.
|
#' @rdname column
|
||||||
#'
|
#' @return the number of distinct items in a group.
|
||||||
setMethod("countDistinct",
|
setMethod("countDistinct",
|
||||||
signature(x = "Column"),
|
signature(x = "Column"),
|
||||||
function(x, ...) {
|
function(x, ...) {
|
||||||
|
@ -197,3 +206,18 @@ setMethod("countDistinct",
|
||||||
column(jc)
|
column(jc)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' @rdname column
|
||||||
|
#' @aliases countDistinct
|
||||||
|
setMethod("n_distinct",
|
||||||
|
signature(x = "Column"),
|
||||||
|
function(x, ...) {
|
||||||
|
countDistinct(x, ...)
|
||||||
|
})
|
||||||
|
|
||||||
|
#' @rdname column
|
||||||
|
#' @aliases count
|
||||||
|
setMethod("n",
|
||||||
|
signature(x = "Column"),
|
||||||
|
function(x) {
|
||||||
|
count(x)
|
||||||
|
})
|
||||||
|
|
|
@ -380,6 +380,14 @@ setGeneric("value", function(bcast) { standardGeneric("value") })
|
||||||
|
|
||||||
#################### DataFrame Methods ########################
|
#################### DataFrame Methods ########################
|
||||||
|
|
||||||
|
#' @rdname agg
|
||||||
|
#' @export
|
||||||
|
setGeneric("agg", function (x, ...) { standardGeneric("agg") })
|
||||||
|
|
||||||
|
#' @rdname arrange
|
||||||
|
#' @export
|
||||||
|
setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
|
||||||
|
|
||||||
#' @rdname schema
|
#' @rdname schema
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("columns", function(x) {standardGeneric("columns") })
|
setGeneric("columns", function(x) {standardGeneric("columns") })
|
||||||
|
@ -404,6 +412,10 @@ setGeneric("except", function(x, y) { standardGeneric("except") })
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
|
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
|
||||||
|
|
||||||
|
#' @rdname groupBy
|
||||||
|
#' @export
|
||||||
|
setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
|
||||||
|
|
||||||
#' @rdname DataFrame
|
#' @rdname DataFrame
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
|
setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
|
||||||
|
@ -424,7 +436,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("limit", function(x, num) {standardGeneric("limit") })
|
setGeneric("limit", function(x, num) {standardGeneric("limit") })
|
||||||
|
|
||||||
#' @rdname sortDF
|
#' @rdname withColumn
|
||||||
|
#' @export
|
||||||
|
setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
|
||||||
|
|
||||||
|
#' @rdname arrange
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
|
setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
|
||||||
|
|
||||||
|
@ -432,10 +448,21 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
|
setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
|
||||||
|
|
||||||
|
#' @rdname withColumnRenamed
|
||||||
|
#' @export
|
||||||
|
setGeneric("rename", function(x, ...) { standardGeneric("rename") })
|
||||||
|
|
||||||
#' @rdname registerTempTable
|
#' @rdname registerTempTable
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
|
setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
|
||||||
|
|
||||||
|
#' @rdname sampleDF
|
||||||
|
#' @export
|
||||||
|
setGeneric("sample_frac",
|
||||||
|
function(x, withReplacement, fraction, seed) {
|
||||||
|
standardGeneric("sample_frac")
|
||||||
|
})
|
||||||
|
|
||||||
#' @rdname sampleDF
|
#' @rdname sampleDF
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("sampleDF",
|
setGeneric("sampleDF",
|
||||||
|
@ -473,9 +500,9 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
|
setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
|
||||||
|
|
||||||
#' @rdname sortDF
|
#' @rdname agg
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
|
setGeneric("summarize", function(x,...) { standardGeneric("summarize") })
|
||||||
|
|
||||||
# @rdname tojson
|
# @rdname tojson
|
||||||
# @export
|
# @export
|
||||||
|
@ -564,6 +591,14 @@ setGeneric("like", function(x, ...) { standardGeneric("like") })
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("lower", function(x) { standardGeneric("lower") })
|
setGeneric("lower", function(x) { standardGeneric("lower") })
|
||||||
|
|
||||||
|
#' @rdname column
|
||||||
|
#' @export
|
||||||
|
setGeneric("n", function(x) { standardGeneric("n") })
|
||||||
|
|
||||||
|
#' @rdname column
|
||||||
|
#' @export
|
||||||
|
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
|
||||||
|
|
||||||
#' @rdname column
|
#' @rdname column
|
||||||
#' @export
|
#' @export
|
||||||
setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
|
setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
|
||||||
|
|
|
@ -56,6 +56,7 @@ setMethod("show", "GroupedData",
|
||||||
#'
|
#'
|
||||||
#' @param x a GroupedData
|
#' @param x a GroupedData
|
||||||
#' @return a DataFrame
|
#' @return a DataFrame
|
||||||
|
#' @rdname agg
|
||||||
#' @export
|
#' @export
|
||||||
#' @examples
|
#' @examples
|
||||||
#' \dontrun{
|
#' \dontrun{
|
||||||
|
@ -83,8 +84,6 @@ setMethod("count",
|
||||||
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
|
#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
|
||||||
#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
|
#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
|
||||||
#' }
|
#' }
|
||||||
setGeneric("agg", function (x, ...) { standardGeneric("agg") })
|
|
||||||
|
|
||||||
setMethod("agg",
|
setMethod("agg",
|
||||||
signature(x = "GroupedData"),
|
signature(x = "GroupedData"),
|
||||||
function(x, ...) {
|
function(x, ...) {
|
||||||
|
@ -112,6 +111,13 @@ setMethod("agg",
|
||||||
dataFrame(sdf)
|
dataFrame(sdf)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#' @rdname agg
|
||||||
|
#' @aliases agg
|
||||||
|
setMethod("summarize",
|
||||||
|
signature(x = "GroupedData"),
|
||||||
|
function(x, ...) {
|
||||||
|
agg(x, ...)
|
||||||
|
})
|
||||||
|
|
||||||
# sum/mean/avg/min/max
|
# sum/mean/avg/min/max
|
||||||
methods <- c("sum", "mean", "avg", "min", "max")
|
methods <- c("sum", "mean", "avg", "min", "max")
|
||||||
|
|
|
@ -428,6 +428,10 @@ test_that("sampleDF on a DataFrame", {
|
||||||
expect_true(inherits(sampled, "DataFrame"))
|
expect_true(inherits(sampled, "DataFrame"))
|
||||||
sampled2 <- sampleDF(df, FALSE, 0.1)
|
sampled2 <- sampleDF(df, FALSE, 0.1)
|
||||||
expect_true(count(sampled2) < 3)
|
expect_true(count(sampled2) < 3)
|
||||||
|
|
||||||
|
# Also test sample_frac
|
||||||
|
sampled3 <- sample_frac(df, FALSE, 0.1)
|
||||||
|
expect_true(count(sampled3) < 3)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("select operators", {
|
test_that("select operators", {
|
||||||
|
@ -533,6 +537,7 @@ test_that("column functions", {
|
||||||
c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
|
c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
|
||||||
c3 <- lower(c) + upper(c) + first(c) + last(c)
|
c3 <- lower(c) + upper(c) + first(c) + last(c)
|
||||||
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
|
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
|
||||||
|
c5 <- n(c) + n_distinct(c)
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("string operators", {
|
test_that("string operators", {
|
||||||
|
@ -557,6 +562,13 @@ test_that("group by", {
|
||||||
expect_true(inherits(df2, "DataFrame"))
|
expect_true(inherits(df2, "DataFrame"))
|
||||||
expect_true(3 == count(df2))
|
expect_true(3 == count(df2))
|
||||||
|
|
||||||
|
# Also test group_by, summarize, mean
|
||||||
|
gd1 <- group_by(df, "name")
|
||||||
|
expect_true(inherits(gd1, "GroupedData"))
|
||||||
|
df_summarized <- summarize(gd, mean_age = mean(df$age))
|
||||||
|
expect_true(inherits(df_summarized, "DataFrame"))
|
||||||
|
expect_true(3 == count(df_summarized))
|
||||||
|
|
||||||
df3 <- agg(gd, age = "sum")
|
df3 <- agg(gd, age = "sum")
|
||||||
expect_true(inherits(df3, "DataFrame"))
|
expect_true(inherits(df3, "DataFrame"))
|
||||||
expect_true(3 == count(df3))
|
expect_true(3 == count(df3))
|
||||||
|
@ -573,12 +585,12 @@ test_that("group by", {
|
||||||
expect_true(3 == count(max(gd, "age")))
|
expect_true(3 == count(max(gd, "age")))
|
||||||
})
|
})
|
||||||
|
|
||||||
test_that("sortDF() and orderBy() on a DataFrame", {
|
test_that("arrange() and orderBy() on a DataFrame", {
|
||||||
df <- jsonFile(sqlCtx, jsonPath)
|
df <- jsonFile(sqlCtx, jsonPath)
|
||||||
sorted <- sortDF(df, df$age)
|
sorted <- arrange(df, df$age)
|
||||||
expect_true(collect(sorted)[1,2] == "Michael")
|
expect_true(collect(sorted)[1,2] == "Michael")
|
||||||
|
|
||||||
sorted2 <- sortDF(df, "name")
|
sorted2 <- arrange(df, "name")
|
||||||
expect_true(collect(sorted2)[2,"age"] == 19)
|
expect_true(collect(sorted2)[2,"age"] == 19)
|
||||||
|
|
||||||
sorted3 <- orderBy(df, asc(df$age))
|
sorted3 <- orderBy(df, asc(df$age))
|
||||||
|
@ -659,17 +671,17 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
|
||||||
writeLines(lines, jsonPath2)
|
writeLines(lines, jsonPath2)
|
||||||
df2 <- loadDF(sqlCtx, jsonPath2, "json")
|
df2 <- loadDF(sqlCtx, jsonPath2, "json")
|
||||||
|
|
||||||
unioned <- sortDF(unionAll(df, df2), df$age)
|
unioned <- arrange(unionAll(df, df2), df$age)
|
||||||
expect_true(inherits(unioned, "DataFrame"))
|
expect_true(inherits(unioned, "DataFrame"))
|
||||||
expect_true(count(unioned) == 6)
|
expect_true(count(unioned) == 6)
|
||||||
expect_true(first(unioned)$name == "Michael")
|
expect_true(first(unioned)$name == "Michael")
|
||||||
|
|
||||||
excepted <- sortDF(except(df, df2), desc(df$age))
|
excepted <- arrange(except(df, df2), desc(df$age))
|
||||||
expect_true(inherits(unioned, "DataFrame"))
|
expect_true(inherits(unioned, "DataFrame"))
|
||||||
expect_true(count(excepted) == 2)
|
expect_true(count(excepted) == 2)
|
||||||
expect_true(first(excepted)$name == "Justin")
|
expect_true(first(excepted)$name == "Justin")
|
||||||
|
|
||||||
intersected <- sortDF(intersect(df, df2), df$age)
|
intersected <- arrange(intersect(df, df2), df$age)
|
||||||
expect_true(inherits(unioned, "DataFrame"))
|
expect_true(inherits(unioned, "DataFrame"))
|
||||||
expect_true(count(intersected) == 1)
|
expect_true(count(intersected) == 1)
|
||||||
expect_true(first(intersected)$name == "Andy")
|
expect_true(first(intersected)$name == "Andy")
|
||||||
|
@ -687,6 +699,18 @@ test_that("withColumn() and withColumnRenamed()", {
|
||||||
expect_true(columns(newDF2)[1] == "newerAge")
|
expect_true(columns(newDF2)[1] == "newerAge")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("mutate() and rename()", {
|
||||||
|
df <- jsonFile(sqlCtx, jsonPath)
|
||||||
|
newDF <- mutate(df, newAge = df$age + 2)
|
||||||
|
expect_true(length(columns(newDF)) == 3)
|
||||||
|
expect_true(columns(newDF)[3] == "newAge")
|
||||||
|
expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
|
||||||
|
|
||||||
|
newDF2 <- rename(df, newerAge = df$age)
|
||||||
|
expect_true(length(columns(newDF2)) == 2)
|
||||||
|
expect_true(columns(newDF2)[1] == "newerAge")
|
||||||
|
})
|
||||||
|
|
||||||
test_that("saveDF() on DataFrame and works with parquetFile", {
|
test_that("saveDF() on DataFrame and works with parquetFile", {
|
||||||
df <- jsonFile(sqlCtx, jsonPath)
|
df <- jsonFile(sqlCtx, jsonPath)
|
||||||
saveDF(df, parquetPath, "parquet", mode="overwrite")
|
saveDF(df, parquetPath, "parquet", mode="overwrite")
|
||||||
|
|
|
@ -246,6 +246,22 @@ object functions {
|
||||||
*/
|
*/
|
||||||
def last(columnName: String): Column = last(Column(columnName))
|
def last(columnName: String): Column = last(Column(columnName))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Aggregate function: returns the average of the values in a group.
|
||||||
|
* Alias for avg.
|
||||||
|
*
|
||||||
|
* @group agg_funcs
|
||||||
|
*/
|
||||||
|
def mean(e: Column): Column = avg(e)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Aggregate function: returns the average of the values in a group.
|
||||||
|
* Alias for avg.
|
||||||
|
*
|
||||||
|
* @group agg_funcs
|
||||||
|
*/
|
||||||
|
def mean(columnName: String): Column = avg(columnName)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Aggregate function: returns the minimum value of the expression in a group.
|
* Aggregate function: returns the minimum value of the expression in a group.
|
||||||
*
|
*
|
||||||
|
|
|
@ -308,6 +308,11 @@ class DataFrameSuite extends QueryTest {
|
||||||
testData2.agg(avg('a)),
|
testData2.agg(avg('a)),
|
||||||
Row(2.0))
|
Row(2.0))
|
||||||
|
|
||||||
|
// Also check mean
|
||||||
|
checkAnswer(
|
||||||
|
testData2.agg(mean('a)),
|
||||||
|
Row(2.0))
|
||||||
|
|
||||||
checkAnswer(
|
checkAnswer(
|
||||||
testData2.agg(avg('a), sumDistinct('a)), // non-partial
|
testData2.agg(avg('a), sumDistinct('a)), // non-partial
|
||||||
Row(2.0, 6.0) :: Nil)
|
Row(2.0, 6.0) :: Nil)
|
||||||
|
|
Loading…
Reference in a new issue