[SPARK-34306][SQL][PYTHON][R] Use Snake naming rule across the function APIs
### What changes were proposed in this pull request? This PR completes the snake_case naming rule for the function APIs across the languages, see also SPARK-10621. In more detail, this PR: - Adds `count_distinct` in Scala, Python, and R, and documents that `count_distinct` is encouraged. `countDistinct` was not deprecated because it is pretty commonly used. We could deprecate it in a future release. - (Scala-specific) adds `typedlit` but doesn't deprecate `typedLit`, which is arguably commonly used. Likewise, we could deprecate it in a future release. - Deprecates and renames: - `sumDistinct` -> `sum_distinct` - `bitwiseNOT` -> `bitwise_not` - `shiftLeft` -> `shiftleft` (matched with SQL name in `FunctionRegistry`) - `shiftRight` -> `shiftright` (matched with SQL name in `FunctionRegistry`) - `shiftRightUnsigned` -> `shiftrightunsigned` (matched with SQL name in `FunctionRegistry`) - (Scala-specific) `callUDF` -> `call_udf` ### Why are the changes needed? To keep the naming consistent across the APIs. ### Does this PR introduce _any_ user-facing change? Yes, it deprecates some APIs and adds new renamed APIs as described above. ### How was this patch tested? Unit tests were added. Closes #31408 from HyukjinKwon/SPARK-34306. Authored-by: HyukjinKwon <gurwls223@apache.org> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
9db566a882
commit
30468a9015
|
@ -243,6 +243,7 @@ exportMethods("%<=>%",
|
|||
"base64",
|
||||
"between",
|
||||
"bin",
|
||||
"bitwise_not",
|
||||
"bitwiseNOT",
|
||||
"bround",
|
||||
"cast",
|
||||
|
@ -259,6 +260,7 @@ exportMethods("%<=>%",
|
|||
"cos",
|
||||
"cosh",
|
||||
"count",
|
||||
"count_distinct",
|
||||
"countDistinct",
|
||||
"crc32",
|
||||
"create_array",
|
||||
|
@ -391,8 +393,11 @@ exportMethods("%<=>%",
|
|||
"sha1",
|
||||
"sha2",
|
||||
"shiftLeft",
|
||||
"shiftleft",
|
||||
"shiftRight",
|
||||
"shiftright",
|
||||
"shiftRightUnsigned",
|
||||
"shiftrightunsigned",
|
||||
"shuffle",
|
||||
"sd",
|
||||
"sign",
|
||||
|
@ -415,6 +420,7 @@ exportMethods("%<=>%",
|
|||
"substr",
|
||||
"substring_index",
|
||||
"sum",
|
||||
"sum_distinct",
|
||||
"sumDistinct",
|
||||
"tan",
|
||||
"tanh",
|
||||
|
|
|
@ -484,7 +484,7 @@ setMethod("acosh",
|
|||
#' \dontrun{
|
||||
#' head(select(df, approx_count_distinct(df$gear)))
|
||||
#' head(select(df, approx_count_distinct(df$gear, 0.02)))
|
||||
#' head(select(df, countDistinct(df$gear, df$cyl)))
|
||||
#' head(select(df, count_distinct(df$gear, df$cyl)))
|
||||
#' head(select(df, n_distinct(df$gear)))
|
||||
#' head(distinct(select(df, "gear")))}
|
||||
#' @note approx_count_distinct(Column) since 3.0.0
|
||||
|
@ -635,21 +635,34 @@ setMethod("bin",
|
|||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{bitwise_not}: Computes bitwise NOT.
|
||||
#'
|
||||
#' @rdname column_nonaggregate_functions
|
||||
#' @aliases bitwise_not bitwise_not,Column-method
|
||||
#' @examples
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' head(select(df, bitwise_not(cast(df$vs, "int"))))}
|
||||
#' @note bitwise_not since 3.2.0
|
||||
setMethod("bitwise_not",
|
||||
signature(x = "Column"),
|
||||
function(x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "bitwise_not", x@jc)
|
||||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{bitwiseNOT}: Computes bitwise NOT.
|
||||
#'
|
||||
#' @rdname column_nonaggregate_functions
|
||||
#' @aliases bitwiseNOT bitwiseNOT,Column-method
|
||||
#' @examples
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' head(select(df, bitwiseNOT(cast(df$vs, "int"))))}
|
||||
#' @note bitwiseNOT since 1.5.0
|
||||
setMethod("bitwiseNOT",
|
||||
signature(x = "Column"),
|
||||
function(x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "bitwiseNOT", x@jc)
|
||||
column(jc)
|
||||
.Deprecated("bitwise_not")
|
||||
bitwise_not(x)
|
||||
})
|
||||
|
||||
#' @details
|
||||
|
@ -1936,22 +1949,35 @@ setMethod("sum",
|
|||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{sum_distinct}: Returns the sum of distinct values in the expression.
|
||||
#'
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @aliases sum_distinct sum_distinct,Column-method
|
||||
#' @examples
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' head(select(df, sum_distinct(df$gear)))
|
||||
#' head(distinct(select(df, "gear")))}
|
||||
#' @note sum_distinct since 3.2.0
|
||||
setMethod("sum_distinct",
|
||||
signature(x = "Column"),
|
||||
function(x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "sum_distinct", x@jc)
|
||||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{sumDistinct}: Returns the sum of distinct values in the expression.
|
||||
#'
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @aliases sumDistinct sumDistinct,Column-method
|
||||
#' @examples
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' head(select(df, sumDistinct(df$gear)))
|
||||
#' head(distinct(select(df, "gear")))}
|
||||
#' @note sumDistinct since 1.4.0
|
||||
setMethod("sumDistinct",
|
||||
signature(x = "Column"),
|
||||
function(x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "sumDistinct", x@jc)
|
||||
column(jc)
|
||||
.Deprecated("sum_distinct")
|
||||
sum_distinct(x)
|
||||
})
|
||||
|
||||
#' @details
|
||||
|
@ -2468,22 +2494,36 @@ setMethod("approxCountDistinct",
|
|||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{count_distinct}: Returns the number of distinct items in a group.
|
||||
#'
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @aliases count_distinct count_distinct,Column-method
|
||||
#' @note count_distinct since 3.2.0
|
||||
setMethod("count_distinct",
|
||||
signature(x = "Column"),
|
||||
function(x, ...) {
|
||||
jcols <- lapply(list(...), function(x) {
|
||||
stopifnot(class(x) == "Column")
|
||||
x@jc
|
||||
})
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "count_distinct", x@jc,
|
||||
jcols)
|
||||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{countDistinct}: Returns the number of distinct items in a group.
|
||||
#'
|
||||
#' An alias of \code{count_distinct}, and it is encouraged to use \code{count_distinct} directly.
|
||||
#'
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @aliases countDistinct countDistinct,Column-method
|
||||
#' @note countDistinct since 1.4.0
|
||||
setMethod("countDistinct",
|
||||
signature(x = "Column"),
|
||||
function(x, ...) {
|
||||
jcols <- lapply(list(...), function(x) {
|
||||
stopifnot(class(x) == "Column")
|
||||
x@jc
|
||||
})
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
|
||||
jcols)
|
||||
column(jc)
|
||||
count_distinct(x, ...)
|
||||
})
|
||||
|
||||
#' @details
|
||||
|
@ -2550,7 +2590,7 @@ setMethod("least",
|
|||
#' @note n_distinct since 1.4.0
|
||||
setMethod("n_distinct", signature(x = "Column"),
|
||||
function(x, ...) {
|
||||
countDistinct(x, ...)
|
||||
count_distinct(x, ...)
|
||||
})
|
||||
|
||||
#' @rdname count
|
||||
|
@ -2893,6 +2933,21 @@ setMethod("sha2", signature(y = "Column", x = "numeric"),
|
|||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{shiftleft}: Shifts the given value numBits left. If the given value is a long value,
|
||||
#' this function will return a long value else it will return an integer value.
|
||||
#'
|
||||
#' @rdname column_math_functions
|
||||
#' @aliases shiftleft shiftleft,Column,numeric-method
|
||||
#' @note shiftleft since 3.2.0
|
||||
setMethod("shiftleft", signature(y = "Column", x = "numeric"),
|
||||
function(y, x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions",
|
||||
"shiftleft",
|
||||
y@jc, as.integer(x))
|
||||
column(jc)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{shiftLeft}: Shifts the given value numBits left. If the given value is a long value,
|
||||
#' this function will return a long value else it will return an integer value.
|
||||
|
@ -2901,9 +2956,22 @@ setMethod("sha2", signature(y = "Column", x = "numeric"),
|
|||
#' @aliases shiftLeft shiftLeft,Column,numeric-method
|
||||
#' @note shiftLeft since 1.5.0
|
||||
setMethod("shiftLeft", signature(y = "Column", x = "numeric"),
|
||||
function(y, x) {
|
||||
.Deprecated("shiftleft")
|
||||
shiftleft(y, x)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{shiftright}: (Signed) shifts the given value numBits right. If the given value is a long
|
||||
#' value, it will return a long value else it will return an integer value.
|
||||
#'
|
||||
#' @rdname column_math_functions
|
||||
#' @aliases shiftright shiftright,Column,numeric-method
|
||||
#' @note shiftright since 3.2.0
|
||||
setMethod("shiftright", signature(y = "Column", x = "numeric"),
|
||||
function(y, x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions",
|
||||
"shiftLeft",
|
||||
"shiftright",
|
||||
y@jc, as.integer(x))
|
||||
column(jc)
|
||||
})
|
||||
|
@ -2916,9 +2984,22 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"),
|
|||
#' @aliases shiftRight shiftRight,Column,numeric-method
|
||||
#' @note shiftRight since 1.5.0
|
||||
setMethod("shiftRight", signature(y = "Column", x = "numeric"),
|
||||
function(y, x) {
|
||||
.Deprecated("shiftright")
|
||||
shiftright(y, x)
|
||||
})
|
||||
|
||||
#' @details
|
||||
#' \code{shiftrightunsigned}: (Unsigned) shifts the given value numBits right. If the given value is
|
||||
#' a long value, it will return a long value else it will return an integer value.
|
||||
#'
|
||||
#' @rdname column_math_functions
|
||||
#' @aliases shiftrightunsigned shiftrightunsigned,Column,numeric-method
|
||||
#' @note shiftrightunsigned since 3.2.0
|
||||
setMethod("shiftrightunsigned", signature(y = "Column", x = "numeric"),
|
||||
function(y, x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions",
|
||||
"shiftRight",
|
||||
"shiftrightunsigned",
|
||||
y@jc, as.integer(x))
|
||||
column(jc)
|
||||
})
|
||||
|
@ -2932,10 +3013,8 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"),
|
|||
#' @note shiftRightUnsigned since 1.5.0
|
||||
setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
|
||||
function(y, x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions",
|
||||
"shiftRightUnsigned",
|
||||
y@jc, as.integer(x))
|
||||
column(jc)
|
||||
.Deprecated("shiftrightunsigned")
|
||||
shiftrightunsigned(y, x)
|
||||
})
|
||||
|
||||
#' @details
|
||||
|
|
|
@ -884,6 +884,10 @@ setGeneric("base64", function(x) { standardGeneric("base64") })
|
|||
#' @name NULL
|
||||
setGeneric("bin", function(x) { standardGeneric("bin") })
|
||||
|
||||
#' @rdname column_nonaggregate_functions
|
||||
#' @name NULL
|
||||
setGeneric("bitwise_not", function(x) { standardGeneric("bitwise_not") })
|
||||
|
||||
#' @rdname column_nonaggregate_functions
|
||||
#' @name NULL
|
||||
setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") })
|
||||
|
@ -923,6 +927,10 @@ setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") })
|
|||
#' @name NULL
|
||||
setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") })
|
||||
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @name NULL
|
||||
setGeneric("count_distinct", function(x, ...) { standardGeneric("count_distinct") })
|
||||
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @name NULL
|
||||
setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") })
|
||||
|
@ -1324,14 +1332,26 @@ setGeneric("sha2", function(y, x) { standardGeneric("sha2") })
|
|||
#' @name NULL
|
||||
setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") })
|
||||
|
||||
#' @rdname column_math_functions
|
||||
#' @name NULL
|
||||
setGeneric("shiftleft", function(y, x) { standardGeneric("shiftleft") })
|
||||
|
||||
#' @rdname column_math_functions
|
||||
#' @name NULL
|
||||
setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") })
|
||||
|
||||
#' @rdname column_math_functions
|
||||
#' @name NULL
|
||||
setGeneric("shiftright", function(y, x) { standardGeneric("shiftright") })
|
||||
|
||||
#' @rdname column_math_functions
|
||||
#' @name NULL
|
||||
setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") })
|
||||
|
||||
#' @rdname column_math_functions
|
||||
#' @name NULL
|
||||
setGeneric("shiftrightunsigned", function(y, x) { standardGeneric("shiftrightunsigned") })
|
||||
|
||||
#' @rdname column_collection_functions
|
||||
#' @name NULL
|
||||
setGeneric("shuffle", function(x) { standardGeneric("shuffle") })
|
||||
|
@ -1388,6 +1408,10 @@ setGeneric("struct", function(x, ...) { standardGeneric("struct") })
|
|||
#' @name NULL
|
||||
setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") })
|
||||
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @name NULL
|
||||
setGeneric("sum_distinct", function(x) { standardGeneric("sum_distinct") })
|
||||
|
||||
#' @rdname column_aggregate_functions
|
||||
#' @name NULL
|
||||
setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") })
|
||||
|
|
|
@ -1397,7 +1397,8 @@ test_that("column operators", {
|
|||
test_that("column functions", {
|
||||
c <- column("a")
|
||||
c1 <- abs(c) + acos(c) + approx_count_distinct(c) + ascii(c) + asin(c) + atan(c)
|
||||
c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c)
|
||||
c2 <- avg(c) + base64(c) + bin(c) + suppressWarnings(bitwiseNOT(c)) +
|
||||
bitwise_not(c) + cbrt(c) + ceil(c) + cos(c)
|
||||
c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c)
|
||||
c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
|
||||
c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
|
||||
|
@ -1405,7 +1406,8 @@ test_that("column functions", {
|
|||
c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
|
||||
c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id()
|
||||
c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
|
||||
c10 <- sumDistinct(c) + tan(c) + tanh(c) + degrees(c) + radians(c)
|
||||
c10 <- suppressWarnings(sumDistinct(c)) + sum_distinct(c) + tan(c) + tanh(c) +
|
||||
degrees(c) + radians(c)
|
||||
c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c)
|
||||
c12 <- variance(c) + xxhash64(c) + ltrim(c, "a") + rtrim(c, "b") + trim(c, "c")
|
||||
c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1)
|
||||
|
@ -1457,6 +1459,8 @@ test_that("column functions", {
|
|||
expect_equal(collect(df3)[[2, 1]], FALSE)
|
||||
expect_equal(collect(df3)[[3, 1]], TRUE)
|
||||
|
||||
df4 <- select(df, count_distinct(df$age, df$name))
|
||||
expect_equal(collect(df4)[[1, 1]], 2)
|
||||
df4 <- select(df, countDistinct(df$age, df$name))
|
||||
expect_equal(collect(df4)[[1, 1]], 2)
|
||||
|
||||
|
@ -1887,9 +1891,12 @@ test_that("column binary mathfunctions", {
|
|||
expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2))
|
||||
expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2))
|
||||
## nolint end
|
||||
expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16)
|
||||
expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4)
|
||||
expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4)
|
||||
expect_equal(collect(select(df, shiftleft(df$b, 1)))[4, 1], 16)
|
||||
expect_equal(collect(select(df, shiftright(df$b, 1)))[4, 1], 4)
|
||||
expect_equal(collect(select(df, shiftrightunsigned(df$b, 1)))[4, 1], 4)
|
||||
expect_equal(collect(select(df, suppressWarnings(shiftLeft(df$b, 1))))[4, 1], 16)
|
||||
expect_equal(collect(select(df, suppressWarnings(shiftRight(df$b, 1))))[4, 1], 4)
|
||||
expect_equal(collect(select(df, suppressWarnings(shiftRightUnsigned(df$b, 1))))[4, 1], 4)
|
||||
expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric")
|
||||
expect_equal(collect(select(df, rand(1)))[1, 1], 0.636, tolerance = 0.01)
|
||||
expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric")
|
||||
|
|
|
@ -331,7 +331,7 @@ A common flow of grouping and aggregation is
|
|||
|
||||
2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group.
|
||||
|
||||
A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there.
|
||||
A number of widely used functions are supported to aggregate data after grouping, including `avg`, `count_distinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sum_distinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there.
|
||||
|
||||
For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below.
|
||||
|
||||
|
|
|
@ -352,7 +352,7 @@ Scalar functions are functions that return a single value per row, as opposed to
|
|||
|
||||
## Aggregate Functions
|
||||
|
||||
Aggregate functions are functions that return a single value on a group of rows. The [Built-in Aggregation Functions](sql-ref-functions-builtin.html#aggregate-functions) provide common aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc.
|
||||
Aggregate functions are functions that return a single value on a group of rows. The [Built-in Aggregation Functions](sql-ref-functions-builtin.html#aggregate-functions) provide common aggregations such as `count()`, `count_distinct()`, `avg()`, `max()`, `min()`, etc.
|
||||
Users are not limited to the predefined aggregate functions and can create their own. For more details
|
||||
about user defined aggregate functions, please refer to the documentation of
|
||||
[User Defined Aggregate Functions](sql-ref-functions-udf-aggregate.html).
|
||||
|
|
|
@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructField;
|
|||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
// col("...") is preferable to df.col("...")
|
||||
import static org.apache.spark.sql.functions.callUDF;
|
||||
import static org.apache.spark.sql.functions.call_udf;
|
||||
import static org.apache.spark.sql.functions.col;
|
||||
// $example off$
|
||||
|
||||
|
@ -73,12 +73,12 @@ public class JavaTokenizerExample {
|
|||
|
||||
Dataset<Row> tokenized = tokenizer.transform(sentenceDataFrame);
|
||||
tokenized.select("sentence", "words")
|
||||
.withColumn("tokens", callUDF("countTokens", col("words")))
|
||||
.withColumn("tokens", call_udf("countTokens", col("words")))
|
||||
.show(false);
|
||||
|
||||
Dataset<Row> regexTokenized = regexTokenizer.transform(sentenceDataFrame);
|
||||
regexTokenized.select("sentence", "words")
|
||||
.withColumn("tokens", callUDF("countTokens", col("words")))
|
||||
.withColumn("tokens", call_udf("countTokens", col("words")))
|
||||
.show(false);
|
||||
// $example off$
|
||||
|
||||
|
|
|
@ -340,6 +340,7 @@ Functions
|
|||
avg
|
||||
base64
|
||||
bin
|
||||
bitwise_not
|
||||
bitwiseNOT
|
||||
broadcast
|
||||
bround
|
||||
|
@ -358,6 +359,7 @@ Functions
|
|||
cos
|
||||
cosh
|
||||
count
|
||||
count_distinct
|
||||
countDistinct
|
||||
covar_pop
|
||||
covar_samp
|
||||
|
@ -482,9 +484,9 @@ Functions
|
|||
sequence
|
||||
sha1
|
||||
sha2
|
||||
shiftLeft
|
||||
shiftRight
|
||||
shiftRightUnsigned
|
||||
shiftleft
|
||||
shiftright
|
||||
shiftrightunsigned
|
||||
shuffle
|
||||
signum
|
||||
sin
|
||||
|
@ -504,6 +506,7 @@ Functions
|
|||
substring
|
||||
substring_index
|
||||
sum
|
||||
sum_distinct
|
||||
sumDistinct
|
||||
tan
|
||||
tanh
|
||||
|
|
|
@ -206,8 +206,20 @@ def mean(col):
|
|||
def sumDistinct(col):
|
||||
"""
|
||||
Aggregate function: returns the sum of distinct values in the expression.
|
||||
|
||||
.. deprecated:: 3.2.0
|
||||
Use :func:`sum_distinct` instead.
|
||||
"""
|
||||
return _invoke_function_over_column("sumDistinct", col)
|
||||
warnings.warn("Deprecated in 3.2, use sum_distinct instead.", FutureWarning)
|
||||
return sum_distinct(col)
|
||||
|
||||
|
||||
@since(3.2)
|
||||
def sum_distinct(col):
|
||||
"""
|
||||
Aggregate function: returns the sum of distinct values in the expression.
|
||||
"""
|
||||
return _invoke_function_over_column("sum_distinct", col)
|
||||
|
||||
|
||||
def acos(col):
|
||||
|
@ -494,8 +506,20 @@ def toRadians(col):
|
|||
def bitwiseNOT(col):
|
||||
"""
|
||||
Computes bitwise not.
|
||||
|
||||
.. deprecated:: 3.2.0
|
||||
Use :func:`bitwise_not` instead.
|
||||
"""
|
||||
return _invoke_function_over_column("bitwiseNOT", col)
|
||||
warnings.warn("Deprecated in 3.2, use bitwise_not instead.", FutureWarning)
|
||||
return bitwise_not(col)
|
||||
|
||||
|
||||
@since(3.2)
|
||||
def bitwise_not(col):
|
||||
"""
|
||||
Computes bitwise not.
|
||||
"""
|
||||
return _invoke_function_over_column("bitwise_not", col)
|
||||
|
||||
|
||||
@since(2.4)
|
||||
|
@ -810,7 +834,7 @@ def approx_count_distinct(col, rsd=None):
|
|||
col : :class:`Column` or str
|
||||
rsd : float, optional
|
||||
maximum relative standard deviation allowed (default = 0.05).
|
||||
For rsd < 0.01, it is more efficient to use :func:`countDistinct`
|
||||
For rsd < 0.01, it is more efficient to use :func:`count_distinct`
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
@ -928,18 +952,29 @@ def covar_samp(col1, col2):
|
|||
def countDistinct(col, *cols):
|
||||
"""Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
|
||||
|
||||
An alias of :func:`count_distinct`, and it is encouraged to use :func:`count_distinct`
|
||||
directly.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
"""
|
||||
return count_distinct(col, *cols)
|
||||
|
||||
|
||||
def count_distinct(col, *cols):
|
||||
"""Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
|
||||
|
||||
.. versionadded:: 3.2.0
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
|
||||
>>> df.agg(count_distinct(df.age, df.name).alias('c')).collect()
|
||||
[Row(c=2)]
|
||||
|
||||
>>> df.agg(countDistinct("age", "name").alias('c')).collect()
|
||||
>>> df.agg(count_distinct("age", "name").alias('c')).collect()
|
||||
[Row(c=2)]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
|
||||
jc = sc._jvm.functions.count_distinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
|
||||
return Column(jc)
|
||||
|
||||
|
||||
|
@ -1255,13 +1290,25 @@ def shiftLeft(col, numBits):
|
|||
|
||||
.. versionadded:: 1.5.0
|
||||
|
||||
.. deprecated:: 3.2.0
|
||||
Use :func:`shiftleft` instead.
|
||||
"""
|
||||
warnings.warn("Deprecated in 3.2, use shiftleft instead.", FutureWarning)
|
||||
return shiftleft(col, numBits)
|
||||
|
||||
|
||||
def shiftleft(col, numBits):
|
||||
"""Shift the given value numBits left.
|
||||
|
||||
.. versionadded:: 3.2.0
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
|
||||
>>> spark.createDataFrame([(21,)], ['a']).select(shiftleft('a', 1).alias('r')).collect()
|
||||
[Row(r=42)]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits))
|
||||
return Column(sc._jvm.functions.shiftleft(_to_java_column(col), numBits))
|
||||
|
||||
|
||||
def shiftRight(col, numBits):
|
||||
|
@ -1269,9 +1316,21 @@ def shiftRight(col, numBits):
|
|||
|
||||
.. versionadded:: 1.5.0
|
||||
|
||||
.. deprecated:: 3.2.0
|
||||
Use :func:`shiftright` instead.
|
||||
"""
|
||||
warnings.warn("Deprecated in 3.2, use shiftright instead.", FutureWarning)
|
||||
return shiftright(col, numBits)
|
||||
|
||||
|
||||
def shiftright(col, numBits):
|
||||
"""(Signed) shift the given value numBits right.
|
||||
|
||||
.. versionadded:: 3.2.0
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
|
||||
>>> spark.createDataFrame([(42,)], ['a']).select(shiftright('a', 1).alias('r')).collect()
|
||||
[Row(r=21)]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
|
@ -1284,10 +1343,22 @@ def shiftRightUnsigned(col, numBits):
|
|||
|
||||
.. versionadded:: 1.5.0
|
||||
|
||||
.. deprecated:: 3.2.0
|
||||
Use :func:`shiftrightunsigned` instead.
|
||||
"""
|
||||
warnings.warn("Deprecated in 3.2, use shiftrightunsigned instead.", FutureWarning)
|
||||
return shiftrightunsigned(col, numBits)
|
||||
|
||||
|
||||
def shiftrightunsigned(col, numBits):
|
||||
"""Unsigned shift the given value numBits right.
|
||||
|
||||
.. versionadded:: 3.2.0
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> df = spark.createDataFrame([(-42,)], ['a'])
|
||||
>>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
|
||||
>>> df.select(shiftrightunsigned('a', 1).alias('r')).collect()
|
||||
[Row(r=9223372036854775787)]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
|
|
|
@ -45,6 +45,7 @@ def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
|
|||
def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
|
||||
def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
|
||||
def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: ...
|
||||
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: ...
|
||||
def first(col: ColumnOrName, ignorenulls: bool = ...) -> Column: ...
|
||||
def grouping(col: ColumnOrName) -> Column: ...
|
||||
def grouping_id(*cols: ColumnOrName) -> Column: ...
|
||||
|
@ -64,8 +65,11 @@ def randn(seed: Optional[int] = ...) -> Column: ...
|
|||
def round(col: ColumnOrName, scale: int = ...) -> Column: ...
|
||||
def bround(col: ColumnOrName, scale: int = ...) -> Column: ...
|
||||
def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ...
|
||||
def shiftleft(col: ColumnOrName, numBits: int) -> Column: ...
|
||||
def shiftRight(col: ColumnOrName, numBits: int) -> Column: ...
|
||||
def shiftright(col: ColumnOrName, numBits: int) -> Column: ...
|
||||
def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ...
|
||||
def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column: ...
|
||||
def spark_partition_id() -> Column: ...
|
||||
def expr(str: str) -> Column: ...
|
||||
def struct(*cols: ColumnOrName) -> Column: ...
|
||||
|
@ -278,6 +282,7 @@ def atan2(col1: ColumnOrName, col2: float) -> Column: ...
|
|||
def avg(col: ColumnOrName) -> Column: ...
|
||||
def base64(col: ColumnOrName) -> Column: ...
|
||||
def bitwiseNOT(col: ColumnOrName) -> Column: ...
|
||||
def bitwise_not(col: ColumnOrName) -> Column: ...
|
||||
def cbrt(col: ColumnOrName) -> Column: ...
|
||||
def ceil(col: ColumnOrName) -> Column: ...
|
||||
def col(col: str) -> Column: ...
|
||||
|
@ -333,6 +338,7 @@ def stddev_pop(col: ColumnOrName) -> Column: ...
|
|||
def stddev_samp(col: ColumnOrName) -> Column: ...
|
||||
def sum(col: ColumnOrName) -> Column: ...
|
||||
def sumDistinct(col: ColumnOrName) -> Column: ...
|
||||
def sum_distinct(col: ColumnOrName) -> Column: ...
|
||||
def tan(col: ColumnOrName) -> Column: ...
|
||||
def tanh(col: ColumnOrName) -> Column: ...
|
||||
def toDegrees(col: ColumnOrName) -> Column: ...
|
||||
|
|
|
@ -139,6 +139,8 @@ class ColumnTests(ReusedSQLTestCase):
|
|||
self.assertEqual(170 ^ 75, result['(a ^ b)'])
|
||||
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
|
||||
self.assertEqual(~75, result['~b'])
|
||||
result = df.select(functions.bitwise_not(df.b)).collect()[0].asDict()
|
||||
self.assertEqual(~75, result['~b'])
|
||||
|
||||
def test_with_field(self):
|
||||
from pyspark.sql.functions import lit, col
|
||||
|
|
|
@ -21,7 +21,9 @@ import re
|
|||
|
||||
from py4j.protocol import Py4JJavaError
|
||||
from pyspark.sql import Row, Window
|
||||
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, lit
|
||||
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
|
||||
lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
|
||||
shiftright, shiftrightunsigned, shiftRightUnsigned
|
||||
from pyspark.testing.sqlutils import ReusedSQLTestCase
|
||||
|
||||
|
||||
|
@ -640,6 +642,23 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
str(cm.exception)
|
||||
)
|
||||
|
||||
def test_sum_distinct(self):
|
||||
self.spark.range(10).select(
|
||||
assert_true(sum_distinct(col("id")) == sumDistinct(col("id")))).collect()
|
||||
|
||||
def test_shiftleft(self):
|
||||
self.spark.range(10).select(
|
||||
assert_true(shiftLeft(col("id"), 2) == shiftleft(col("id"), 2))).collect()
|
||||
|
||||
def test_shiftright(self):
|
||||
self.spark.range(10).select(
|
||||
assert_true(shiftRight(col("id"), 2) == shiftright(col("id"), 2))).collect()
|
||||
|
||||
def test_shiftrightunsigned(self):
|
||||
self.spark.range(10).select(
|
||||
assert_true(
|
||||
shiftRightUnsigned(col("id"), 2) == shiftrightunsigned(col("id"), 2))).collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -31,6 +31,7 @@ class GroupTests(ReusedSQLTestCase):
|
|||
self.assertEqual((0, u'99'),
|
||||
tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
|
||||
self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0])
|
||||
# test countDistinct, the camelCase alias of count_distinct (kept for compatibility, not deprecated)
|
||||
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
|
||||
|
||||
|
||||
|
|
|
@ -39,8 +39,8 @@ import org.apache.spark.sql.types.IntegerType
|
|||
*
|
||||
* val agg = data.groupBy($"key")
|
||||
* .agg(
|
||||
* countDistinct($"cat1").as("cat1_cnt"),
|
||||
* countDistinct($"cat2").as("cat2_cnt"),
|
||||
* count_distinct($"cat1").as("cat1_cnt"),
|
||||
* count_distinct($"cat2").as("cat2_cnt"),
|
||||
* sum($"value").as("total"))
|
||||
* }}}
|
||||
*
|
||||
|
|
|
@ -2319,7 +2319,7 @@ class Dataset[T] private[sql](
|
|||
*
|
||||
* val allWords = ds.select('title, explode(split('words, " ")).as("word"))
|
||||
*
|
||||
* val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
|
||||
* val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title"))
|
||||
* }}}
|
||||
*
|
||||
* Using `flatMap()` this can similarly be exploded as:
|
||||
|
|
|
@ -113,6 +113,16 @@ object functions {
|
|||
*/
|
||||
def lit(literal: Any): Column = typedLit(literal)
|
||||
|
||||
/**
|
||||
* Creates a [[Column]] of literal value.
|
||||
*
|
||||
* An alias of `typedlit`, and it is encouraged to use `typedlit` directly.
|
||||
*
|
||||
* @group normal_funcs
|
||||
* @since 2.2.0
|
||||
*/
|
||||
def typedLit[T : TypeTag](literal: T): Column = typedlit(literal)
|
||||
|
||||
/**
|
||||
* Creates a [[Column]] of literal value.
|
||||
*
|
||||
|
@ -123,9 +133,9 @@ object functions {
|
|||
* can handle parameterized scala types e.g.: List, Seq and Map.
|
||||
*
|
||||
* @group normal_funcs
|
||||
* @since 2.2.0
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def typedLit[T : TypeTag](literal: T): Column = literal match {
|
||||
def typedlit[T : TypeTag](literal: T): Column = literal match {
|
||||
case c: Column => c
|
||||
case s: Symbol => new ColumnName(s.name)
|
||||
case _ => Column(Literal.create(literal))
|
||||
|
@ -388,24 +398,37 @@ object functions {
|
|||
/**
|
||||
* Aggregate function: returns the number of distinct items in a group.
|
||||
*
|
||||
* An alias of `count_distinct`, and it is encouraged to use `count_distinct` directly.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.3.0
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def countDistinct(expr: Column, exprs: Column*): Column =
|
||||
// For usage like countDistinct("*"), we should let analyzer expand star and
|
||||
// resolve function.
|
||||
Column(UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true))
|
||||
def countDistinct(expr: Column, exprs: Column*): Column = count_distinct(expr, exprs: _*)
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the number of distinct items in a group.
|
||||
*
|
||||
* An alias of `count_distinct`, and it is encouraged to use `count_distinct` directly.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 1.3.0
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def countDistinct(columnName: String, columnNames: String*): Column =
|
||||
countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
|
||||
count_distinct(Column(columnName), columnNames.map(Column.apply) : _*)
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the number of distinct items in a group.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def count_distinct(expr: Column, exprs: Column*): Column =
|
||||
// For usage like count_distinct("*"), we should let analyzer expand star and
|
||||
// resolve function.
|
||||
Column(UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true))
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the population covariance for two columns.
|
||||
|
@ -796,6 +819,7 @@ object functions {
|
|||
* @group agg_funcs
|
||||
* @since 1.3.0
|
||||
*/
|
||||
@deprecated("Use sum_distinct", "3.2.0")
|
||||
def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true)
|
||||
|
||||
/**
|
||||
|
@ -804,7 +828,16 @@ object functions {
|
|||
* @group agg_funcs
|
||||
* @since 1.3.0
|
||||
*/
|
||||
def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
|
||||
@deprecated("Use sum_distinct", "3.2.0")
|
||||
def sumDistinct(columnName: String): Column = sum_distinct(Column(columnName))
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the sum of distinct values in the expression.
|
||||
*
|
||||
* @group agg_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def sum_distinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true)
|
||||
|
||||
/**
|
||||
* Aggregate function: alias for `var_samp`.
|
||||
|
@ -1411,7 +1444,16 @@ object functions {
|
|||
* @group normal_funcs
|
||||
* @since 1.4.0
|
||||
*/
|
||||
def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) }
|
||||
@deprecated("Use bitwise_not", "3.2.0")
|
||||
def bitwiseNOT(e: Column): Column = bitwise_not(e)
|
||||
|
||||
/**
|
||||
* Computes bitwise NOT (~) of a number.
|
||||
*
|
||||
* @group normal_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def bitwise_not(e: Column): Column = withExpr { BitwiseNot(e.expr) }
|
||||
|
||||
/**
|
||||
* Parses the expression string into the column that it represents, similar to
|
||||
|
@ -2142,7 +2184,17 @@ object functions {
|
|||
* @group math_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) }
|
||||
@deprecated("Use shiftleft", "3.2.0")
|
||||
def shiftLeft(e: Column, numBits: Int): Column = shiftleft(e, numBits)
|
||||
|
||||
/**
|
||||
* Shift the given value numBits left. If the given value is a long value, this function
|
||||
* will return a long value else it will return an integer value.
|
||||
*
|
||||
* @group math_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def shiftleft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) }
|
||||
|
||||
/**
|
||||
* (Signed) shift the given value numBits right. If the given value is a long value, it will
|
||||
|
@ -2151,7 +2203,17 @@ object functions {
|
|||
* @group math_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def shiftRight(e: Column, numBits: Int): Column = withExpr {
|
||||
@deprecated("Use shiftright", "3.2.0")
|
||||
def shiftRight(e: Column, numBits: Int): Column = shiftright(e, numBits)
|
||||
|
||||
/**
|
||||
* (Signed) shift the given value numBits right. If the given value is a long value, it will
|
||||
* return a long value else it will return an integer value.
|
||||
*
|
||||
* @group math_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def shiftright(e: Column, numBits: Int): Column = withExpr {
|
||||
ShiftRight(e.expr, lit(numBits).expr)
|
||||
}
|
||||
|
||||
|
@ -2162,7 +2224,17 @@ object functions {
|
|||
* @group math_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr {
|
||||
@deprecated("Use shiftrightunsigned", "3.2.0")
|
||||
def shiftRightUnsigned(e: Column, numBits: Int): Column = shiftrightunsigned(e, numBits)
|
||||
|
||||
/**
|
||||
* Unsigned shift the given value numBits right. If the given value is a long value,
|
||||
* it will return a long value else it will return an integer value.
|
||||
*
|
||||
* @group math_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def shiftrightunsigned(e: Column, numBits: Int): Column = withExpr {
|
||||
ShiftRightUnsigned(e.expr, lit(numBits).expr)
|
||||
}
|
||||
|
||||
|
@ -5073,6 +5145,17 @@ object functions {
|
|||
SparkUserDefinedFunction(f, dataType, inputEncoders = Nil)
|
||||
}
|
||||
|
||||
/**
|
||||
* Call a user-defined function.
|
||||
*
|
||||
* @group udf_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
@deprecated("Use call_udf")
|
||||
def callUDF(udfName: String, cols: Column*): Column =
|
||||
call_udf(udfName, cols: _*)
|
||||
|
||||
/**
|
||||
* Call a user-defined function.
|
||||
* Example:
|
||||
|
@ -5082,14 +5165,14 @@ object functions {
|
|||
* val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
|
||||
* val spark = df.sparkSession
|
||||
* spark.udf.register("simpleUDF", (v: Int) => v * v)
|
||||
* df.select($"id", callUDF("simpleUDF", $"value"))
|
||||
* df.select($"id", call_udf("simpleUDF", $"value"))
|
||||
* }}}
|
||||
*
|
||||
* @group udf_funcs
|
||||
* @since 1.5.0
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def callUDF(udfName: String, cols: Column*): Column = withExpr {
|
||||
def call_udf(udfName: String, cols: Column*): Column = withExpr {
|
||||
UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -105,7 +105,7 @@ public class JavaDataFrameSuite {
|
|||
|
||||
// Varargs in column expressions
|
||||
df.groupBy().agg(countDistinct("key", "value"));
|
||||
df.groupBy().agg(countDistinct(col("key"), col("value")));
|
||||
df.groupBy().agg(count_distinct(col("key"), col("value")));
|
||||
df.select(coalesce(col("key")));
|
||||
|
||||
// Varargs with mathfunctions
|
||||
|
|
|
@ -297,7 +297,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
Row(2.0, 2.0))
|
||||
|
||||
checkAnswer(
|
||||
testData2.agg(avg($"a"), sumDistinct($"a")), // non-partial
|
||||
testData2.agg(avg($"a"), sumDistinct($"a")), // non-partial and test deprecated version
|
||||
Row(2.0, 6.0) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
|
@ -305,7 +305,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
Row(new java.math.BigDecimal(2)))
|
||||
|
||||
checkAnswer(
|
||||
decimalData.agg(avg($"a"), sumDistinct($"a")), // non-partial
|
||||
decimalData.agg(avg($"a"), sum_distinct($"a")), // non-partial
|
||||
Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
|
@ -314,7 +314,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
// non-partial
|
||||
checkAnswer(
|
||||
decimalData.agg(
|
||||
avg($"a" cast DecimalType(10, 2)), sumDistinct($"a" cast DecimalType(10, 2))),
|
||||
avg($"a" cast DecimalType(10, 2)), sum_distinct($"a" cast DecimalType(10, 2))),
|
||||
Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil)
|
||||
}
|
||||
|
||||
|
@ -324,11 +324,11 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
Row(2.0))
|
||||
|
||||
checkAnswer(
|
||||
testData3.agg(avg($"b"), countDistinct($"b")),
|
||||
testData3.agg(avg($"b"), count_distinct($"b")),
|
||||
Row(2.0, 1))
|
||||
|
||||
checkAnswer(
|
||||
testData3.agg(avg($"b"), sumDistinct($"b")), // non-partial
|
||||
testData3.agg(avg($"b"), sum_distinct($"b")), // non-partial
|
||||
Row(2.0, 2.0))
|
||||
}
|
||||
|
||||
|
@ -339,7 +339,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
Row(null))
|
||||
|
||||
checkAnswer(
|
||||
emptyTableData.agg(avg($"a"), sumDistinct($"b")), // non-partial
|
||||
emptyTableData.agg(avg($"a"), sum_distinct($"b")), // non-partial
|
||||
Row(null, null))
|
||||
}
|
||||
|
||||
|
@ -347,7 +347,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
assert(testData2.count() === testData2.rdd.map(_ => 1).count())
|
||||
|
||||
checkAnswer(
|
||||
testData2.agg(count($"a"), sumDistinct($"a")), // non-partial
|
||||
testData2.agg(count($"a"), sum_distinct($"a")), // non-partial
|
||||
Row(6, 6.0))
|
||||
}
|
||||
|
||||
|
@ -364,12 +364,12 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
|
||||
checkAnswer(
|
||||
testData3.agg(
|
||||
count($"a"), count($"b"), count(lit(1)), countDistinct($"a"), countDistinct($"b")),
|
||||
count($"a"), count($"b"), count(lit(1)), count_distinct($"a"), count_distinct($"b")),
|
||||
Row(2, 1, 2, 2, 1)
|
||||
)
|
||||
|
||||
checkAnswer(
|
||||
testData3.agg(count($"b"), countDistinct($"b"), sumDistinct($"b")), // non-partial
|
||||
testData3.agg(count($"b"), count_distinct($"b"), sum_distinct($"b")), // non-partial
|
||||
Row(1, 1, 2)
|
||||
)
|
||||
}
|
||||
|
@ -384,17 +384,17 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
.toDF("key1", "key2", "key3")
|
||||
|
||||
checkAnswer(
|
||||
df1.agg(countDistinct($"key1", $"key2")),
|
||||
df1.agg(count_distinct($"key1", $"key2")),
|
||||
Row(3)
|
||||
)
|
||||
|
||||
checkAnswer(
|
||||
df1.agg(countDistinct($"key1", $"key2", $"key3")),
|
||||
df1.agg(count_distinct($"key1", $"key2", $"key3")),
|
||||
Row(3)
|
||||
)
|
||||
|
||||
checkAnswer(
|
||||
df1.groupBy($"key1").agg(countDistinct($"key2", $"key3")),
|
||||
df1.groupBy($"key1").agg(count_distinct($"key2", $"key3")),
|
||||
Seq(Row("a", 2), Row("x", 1))
|
||||
)
|
||||
}
|
||||
|
@ -402,7 +402,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
test("zero count") {
|
||||
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
|
||||
checkAnswer(
|
||||
emptyTableData.agg(count($"a"), sumDistinct($"a")), // non-partial
|
||||
emptyTableData.agg(count($"a"), sum_distinct($"a")), // non-partial
|
||||
Row(0, null))
|
||||
}
|
||||
|
||||
|
@ -433,7 +433,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
test("zero sum distinct") {
|
||||
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
|
||||
checkAnswer(
|
||||
emptyTableData.agg(sumDistinct($"a")),
|
||||
emptyTableData.agg(sum_distinct($"a")),
|
||||
Row(null))
|
||||
}
|
||||
|
||||
|
@ -622,7 +622,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d"))
|
||||
.toDF("x", "y", "z")
|
||||
checkAnswer(
|
||||
df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))),
|
||||
df.groupBy($"x").agg(count_distinct($"y"), sort_array(collect_list($"z"))),
|
||||
Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
|
||||
}
|
||||
|
||||
|
@ -837,7 +837,7 @@ class DataFrameAggregateSuite extends QueryTest
|
|||
)
|
||||
}
|
||||
|
||||
test("SPARK-27581: DataFrame countDistinct(\"*\") shouldn't fail with AnalysisException") {
|
||||
test("SPARK-27581: DataFrame count_distinct(\"*\") shouldn't fail with AnalysisException") {
|
||||
val df = sql("select id % 100 from range(100000)")
|
||||
val distinctCount1 = df.select(expr("count(distinct(*))"))
|
||||
val distinctCount2 = df.select(countDistinct("*"))
|
||||
|
|
|
@ -171,10 +171,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
)
|
||||
}
|
||||
|
||||
test("bitwiseNOT") {
|
||||
test("bitwise_not") {
|
||||
checkAnswer(
|
||||
testData2.select(bitwiseNOT($"a")),
|
||||
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
|
||||
testData2.select(bitwiseNOT($"a"), bitwise_not($"a")),
|
||||
testData2.collect().toSeq.map(r => Row(~r.getInt(0), ~r.getInt(0))))
|
||||
}
|
||||
|
||||
test("bin") {
|
||||
|
|
|
@ -133,7 +133,7 @@ class DataFrameSuite extends QueryTest
|
|||
df2
|
||||
.select('_1 as 'letter, 'number)
|
||||
.groupBy('letter)
|
||||
.agg(countDistinct('number)),
|
||||
.agg(count_distinct('number)),
|
||||
Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
|
||||
)
|
||||
}
|
||||
|
@ -513,7 +513,7 @@ class DataFrameSuite extends QueryTest
|
|||
Row(5, false)))
|
||||
|
||||
checkAnswer(
|
||||
testData2.select(sumDistinct($"a")),
|
||||
testData2.select(sum_distinct($"a")),
|
||||
Row(6))
|
||||
}
|
||||
|
||||
|
@ -607,7 +607,7 @@ class DataFrameSuite extends QueryTest
|
|||
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
|
||||
df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)
|
||||
checkAnswer(
|
||||
df.select($"id", callUDF("simpleUDF", $"value")),
|
||||
df.select($"id", callUDF("simpleUDF", $"value")), // test deprecated one
|
||||
Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
|
||||
}
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
|
||||
test("function current_timestamp and now") {
|
||||
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
|
||||
checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
|
||||
checkAnswer(df1.select(count_distinct(current_timestamp())), Row(1))
|
||||
|
||||
// Execution in one query should return the same value
|
||||
checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true))
|
||||
|
|
|
@ -366,14 +366,14 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
|
||||
checkAnswer(
|
||||
df.select(
|
||||
shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
|
||||
shiftLeft('f, 1)),
|
||||
shiftleft('a, 1), shiftleft('b, 1), shiftleft('c, 1), shiftleft('d, 1),
|
||||
shiftLeft('f, 1)), // test deprecated one.
|
||||
Row(42.toLong, 42, 42.toShort, 42.toByte, null))
|
||||
|
||||
checkAnswer(
|
||||
df.selectExpr(
|
||||
"shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)",
|
||||
"shiftLeft(f, 1)"),
|
||||
"shiftleft(a, 1)", "shiftleft(b, 1)", "shiftleft(b, 1)", "shiftleft(d, 1)",
|
||||
"shiftleft(f, 1)"),
|
||||
Row(42.toLong, 42, 42.toShort, 42.toByte, null))
|
||||
}
|
||||
|
||||
|
@ -383,14 +383,14 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
|
||||
checkAnswer(
|
||||
df.select(
|
||||
shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
|
||||
shiftRight('f, 1)),
|
||||
shiftright('a, 1), shiftright('b, 1), shiftright('c, 1), shiftright('d, 1),
|
||||
shiftRight('f, 1)), // test deprecated one.
|
||||
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
|
||||
|
||||
checkAnswer(
|
||||
df.selectExpr(
|
||||
"shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
|
||||
"shiftRight(f, 1)"),
|
||||
"shiftright(a, 1)", "shiftright(b, 1)", "shiftright(c, 1)", "shiftright(d, 1)",
|
||||
"shiftright(f, 1)"),
|
||||
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
|
||||
}
|
||||
|
||||
|
@ -400,14 +400,14 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
|
||||
checkAnswer(
|
||||
df.select(
|
||||
shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
|
||||
shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)),
|
||||
shiftrightunsigned('a, 1), shiftrightunsigned('b, 1), shiftrightunsigned('c, 1),
|
||||
shiftrightunsigned('d, 1), shiftRightUnsigned('f, 1)), // test deprecated one.
|
||||
Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
|
||||
|
||||
checkAnswer(
|
||||
df.selectExpr(
|
||||
"shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
|
||||
"shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"),
|
||||
"shiftrightunsigned(a, 1)", "shiftrightunsigned(b, 1)", "shiftrightunsigned(c, 1)",
|
||||
"shiftrightunsigned(d, 1)", "shiftrightunsigned(f, 1)"),
|
||||
Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
|
||||
}
|
||||
|
||||
|
|
|
@ -61,13 +61,13 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
|
|||
}
|
||||
|
||||
test("count distinct is partially aggregated") {
|
||||
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
|
||||
val query = testData.groupBy('value).agg(count_distinct('key)).queryExecution.analyzed
|
||||
testPartialAggregationPlan(query)
|
||||
}
|
||||
|
||||
test("mixed aggregates are partially aggregated") {
|
||||
val query =
|
||||
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
|
||||
testData.groupBy('value).agg(count('value), count_distinct('key)).queryExecution.analyzed
|
||||
testPartialAggregationPlan(query)
|
||||
}
|
||||
|
||||
|
|
|
@ -120,8 +120,8 @@ class SameResultSuite extends QueryTest with SharedSparkSession {
|
|||
val df2 = spark.range(10).agg(sum($"id"))
|
||||
assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
|
||||
|
||||
val df3 = spark.range(10).agg(sumDistinct($"id"))
|
||||
val df4 = spark.range(10).agg(sumDistinct($"id"))
|
||||
val df3 = spark.range(10).agg(sum_distinct($"id"))
|
||||
val df4 = spark.range(10).agg(sum_distinct($"id"))
|
||||
assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
|
||||
}
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ public class JavaDataFrameSuite {
|
|||
udaf.distinct(col("value")),
|
||||
udaf.apply(col("value")),
|
||||
registeredUDAF.apply(col("value")),
|
||||
callUDF("mydoublesum", col("value")));
|
||||
callUDF("mydoublesum", col("value"))); // test deprecated one
|
||||
|
||||
List<Row> expectedResult = new ArrayList<>();
|
||||
expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));
|
||||
|
|
|
@ -711,7 +711,7 @@ object SPARK_9757 extends QueryTest {
|
|||
val df =
|
||||
hiveContext
|
||||
.range(10)
|
||||
.select(callUDF("struct", ($"id" + 0.2) cast DecimalType(10, 3)) as "dec_struct")
|
||||
.select(call_udf("struct", ($"id" + 0.2) cast DecimalType(10, 3)) as "dec_struct")
|
||||
df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
|
||||
checkAnswer(hiveContext.table("t"), df)
|
||||
}
|
||||
|
|
|
@ -219,7 +219,7 @@ class ObjectHashAggregateSuite
|
|||
val withPartialSafe = max($"c2")
|
||||
|
||||
// A Spark SQL native distinct aggregate function
|
||||
val withDistinct = countDistinct($"c3")
|
||||
val withDistinct = count_distinct($"c3")
|
||||
|
||||
val allAggs = Seq(
|
||||
"typed" -> typed,
|
||||
|
|
Loading…
Reference in a new issue