[SPARK-34306][SQL][PYTHON][R] Use Snake naming rule across the function APIs

### What changes were proposed in this pull request?

This PR completes the snake_case naming rule for the function APIs across languages; see also SPARK-10621.

In more detail, this PR (a short usage sketch follows the list):
- Adds `count_distinct` in Scala, Python, and R, and documents that `count_distinct` is encouraged. `countDistinct` was not deprecated because it is very commonly used; we could deprecate it in future releases.
- (Scala-specific) Adds `typedlit` but does not deprecate `typedLit`, which is arguably commonly used. Likewise, we could deprecate it in future releases.
- Deprecates and renames:
  - `sumDistinct` -> `sum_distinct`
  - `bitwiseNOT` -> `bitwise_not`
  - `shiftLeft` -> `shiftleft` (matches the SQL name in `FunctionRegistry`)
  - `shiftRight` -> `shiftright` (matches the SQL name in `FunctionRegistry`)
  - `shiftRightUnsigned` -> `shiftrightunsigned` (matches the SQL name in `FunctionRegistry`)
  - (Scala-specific) `callUDF` -> `call_udf`
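
For illustration, a minimal Scala sketch of the new and old spellings (assuming Spark 3.2; the toy data here is made up):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val df = Seq((1, 2), (1, 4), (3, 4)).toDF("a", "b")

// New snake_case spellings added by this PR.
df.agg(count_distinct($"a", $"b"), sum_distinct($"a")).show()
df.select(bitwise_not($"b"), shiftleft($"b", 1), shiftright($"b", 1)).show()

// typedlit is added alongside typedLit; neither is deprecated.
df.select(typedlit(Seq(1, 2, 3))).show()

// The camelCase forms still compile, but most now carry @deprecated;
// countDistinct and typedLit remain plain, non-deprecated aliases.
df.agg(sumDistinct($"a")).show()
```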

### Why are the changes needed?

To keep naming consistent across the function APIs.

### Does this PR introduce _any_ user-facing change?

Yes, it deprecates some APIs and adds new, renamed APIs as described above.

### How was this patch tested?

Unit tests were added.

Closes #31408 from HyukjinKwon/SPARK-34306.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Date: 2021-02-02 09:29:40 +09:00
parent 9db566a882
commit 30468a9015
27 changed files with 414 additions and 113 deletions


@ -243,6 +243,7 @@ exportMethods("%<=>%",
"base64",
"between",
"bin",
"bitwise_not",
"bitwiseNOT",
"bround",
"cast",
@ -259,6 +260,7 @@ exportMethods("%<=>%",
"cos",
"cosh",
"count",
"count_distinct",
"countDistinct",
"crc32",
"create_array",
@ -391,8 +393,11 @@ exportMethods("%<=>%",
"sha1",
"sha2",
"shiftLeft",
"shiftleft",
"shiftRight",
"shiftright",
"shiftRightUnsigned",
"shiftrightunsigned",
"shuffle",
"sd",
"sign",
@ -415,6 +420,7 @@ exportMethods("%<=>%",
"substr",
"substring_index",
"sum",
"sum_distinct",
"sumDistinct",
"tan",
"tanh",


@ -484,7 +484,7 @@ setMethod("acosh",
#' \dontrun{
#' head(select(df, approx_count_distinct(df$gear)))
#' head(select(df, approx_count_distinct(df$gear, 0.02)))
#' head(select(df, countDistinct(df$gear, df$cyl)))
#' head(select(df, count_distinct(df$gear, df$cyl)))
#' head(select(df, n_distinct(df$gear)))
#' head(distinct(select(df, "gear")))}
#' @note approx_count_distinct(Column) since 3.0.0
@ -635,21 +635,34 @@ setMethod("bin",
column(jc)
})
#' @details
#' \code{bitwise_not}: Computes bitwise NOT.
#'
#' @rdname column_nonaggregate_functions
#' @aliases bitwise_not bitwise_not,Column-method
#' @examples
#'
#' \dontrun{
#' head(select(df, bitwise_not(cast(df$vs, "int"))))}
#' @note bitwise_not since 3.2.0
setMethod("bitwise_not",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "bitwise_not", x@jc)
column(jc)
})
#' @details
#' \code{bitwiseNOT}: Computes bitwise NOT.
#'
#' @rdname column_nonaggregate_functions
#' @aliases bitwiseNOT bitwiseNOT,Column-method
#' @examples
#'
#' \dontrun{
#' head(select(df, bitwiseNOT(cast(df$vs, "int"))))}
#' @note bitwiseNOT since 1.5.0
setMethod("bitwiseNOT",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "bitwiseNOT", x@jc)
column(jc)
.Deprecated("bitwise_not")
bitwise_not(x)
})
#' @details
@ -1936,22 +1949,35 @@ setMethod("sum",
column(jc)
})
#' @details
#' \code{sum_distinct}: Returns the sum of distinct values in the expression.
#'
#' @rdname column_aggregate_functions
#' @aliases sum_distinct sum_distinct,Column-method
#' @examples
#'
#' \dontrun{
#' head(select(df, sum_distinct(df$gear)))
#' head(distinct(select(df, "gear")))}
#' @note sum_distinct since 3.2.0
setMethod("sum_distinct",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "sum_distinct", x@jc)
column(jc)
})
#' @details
#' \code{sumDistinct}: Returns the sum of distinct values in the expression.
#'
#' @rdname column_aggregate_functions
#' @aliases sumDistinct sumDistinct,Column-method
#' @examples
#'
#' \dontrun{
#' head(select(df, sumDistinct(df$gear)))
#' head(distinct(select(df, "gear")))}
#' @note sumDistinct since 1.4.0
setMethod("sumDistinct",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "sumDistinct", x@jc)
column(jc)
.Deprecated("sum_distinct")
sum_distinct(x)
})
#' @details
@ -2468,22 +2494,36 @@ setMethod("approxCountDistinct",
column(jc)
})
#' @details
#' \code{count_distinct}: Returns the number of distinct items in a group.
#'
#' @rdname column_aggregate_functions
#' @aliases count_distinct count_distinct,Column-method
#' @note count_distinct since 3.2.0
setMethod("count_distinct",
signature(x = "Column"),
function(x, ...) {
jcols <- lapply(list(...), function(x) {
stopifnot(class(x) == "Column")
x@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "count_distinct", x@jc,
jcols)
column(jc)
})
#' @details
#' \code{countDistinct}: Returns the number of distinct items in a group.
#'
#' An alias of \code{count_distinct}, and it is encouraged to use \code{count_distinct} directly.
#'
#' @rdname column_aggregate_functions
#' @aliases countDistinct countDistinct,Column-method
#' @note countDistinct since 1.4.0
setMethod("countDistinct",
signature(x = "Column"),
function(x, ...) {
jcols <- lapply(list(...), function(x) {
stopifnot(class(x) == "Column")
x@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
jcols)
column(jc)
count_distinct(x, ...)
})
#' @details
@ -2550,7 +2590,7 @@ setMethod("least",
#' @note n_distinct since 1.4.0
setMethod("n_distinct", signature(x = "Column"),
function(x, ...) {
countDistinct(x, ...)
count_distinct(x, ...)
})
#' @rdname count
@ -2893,6 +2933,21 @@ setMethod("sha2", signature(y = "Column", x = "numeric"),
column(jc)
})
#' @details
#' \code{shiftleft}: Shifts the given value numBits left. If the given value is a long value,
#' this function will return a long value else it will return an integer value.
#'
#' @rdname column_math_functions
#' @aliases shiftleft shiftleft,Column,numeric-method
#' @note shiftleft since 3.2.0
setMethod("shiftleft", signature(y = "Column", x = "numeric"),
function(y, x) {
jc <- callJStatic("org.apache.spark.sql.functions",
"shiftleft",
y@jc, as.integer(x))
column(jc)
})
#' @details
#' \code{shiftLeft}: Shifts the given value numBits left. If the given value is a long value,
#' this function will return a long value else it will return an integer value.
@ -2901,9 +2956,22 @@ setMethod("sha2", signature(y = "Column", x = "numeric"),
#' @aliases shiftLeft shiftLeft,Column,numeric-method
#' @note shiftLeft since 1.5.0
setMethod("shiftLeft", signature(y = "Column", x = "numeric"),
function(y, x) {
.Deprecated("shiftleft")
shiftleft(y, x)
})
#' @details
#' \code{shiftright}: (Signed) shifts the given value numBits right. If the given value is a long
#' value, it will return a long value else it will return an integer value.
#'
#' @rdname column_math_functions
#' @aliases shiftright shiftright,Column,numeric-method
#' @note shiftright since 3.2.0
setMethod("shiftright", signature(y = "Column", x = "numeric"),
function(y, x) {
jc <- callJStatic("org.apache.spark.sql.functions",
"shiftLeft",
"shiftright",
y@jc, as.integer(x))
column(jc)
})
@ -2916,9 +2984,22 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"),
#' @aliases shiftRight shiftRight,Column,numeric-method
#' @note shiftRight since 1.5.0
setMethod("shiftRight", signature(y = "Column", x = "numeric"),
function(y, x) {
.Deprecated("shiftright")
shiftright(y, x)
})
#' @details
#' \code{shiftrightunsigned}: (Unsigned) shifts the given value numBits right. If the given value is
#' a long value, it will return a long value else it will return an integer value.
#'
#' @rdname column_math_functions
#' @aliases shiftrightunsigned shiftrightunsigned,Column,numeric-method
#' @note shiftrightunsigned since 3.2.0
setMethod("shiftrightunsigned", signature(y = "Column", x = "numeric"),
function(y, x) {
jc <- callJStatic("org.apache.spark.sql.functions",
"shiftRight",
"shiftrightunsigned",
y@jc, as.integer(x))
column(jc)
})
@ -2932,10 +3013,8 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"),
#' @note shiftRightUnsigned since 1.5.0
setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
function(y, x) {
jc <- callJStatic("org.apache.spark.sql.functions",
"shiftRightUnsigned",
y@jc, as.integer(x))
column(jc)
.Deprecated("shiftrightunsigned")
shiftrightunsigned(y, x)
})
#' @details


@ -884,6 +884,10 @@ setGeneric("base64", function(x) { standardGeneric("base64") })
#' @name NULL
setGeneric("bin", function(x) { standardGeneric("bin") })
#' @rdname column_nonaggregate_functions
#' @name NULL
setGeneric("bitwise_not", function(x) { standardGeneric("bitwise_not") })
#' @rdname column_nonaggregate_functions
#' @name NULL
setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") })
@ -923,6 +927,10 @@ setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") })
#' @name NULL
setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") })
#' @rdname column_aggregate_functions
#' @name NULL
setGeneric("count_distinct", function(x, ...) { standardGeneric("count_distinct") })
#' @rdname column_aggregate_functions
#' @name NULL
setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") })
@ -1324,14 +1332,26 @@ setGeneric("sha2", function(y, x) { standardGeneric("sha2") })
#' @name NULL
setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") })
#' @rdname column_math_functions
#' @name NULL
setGeneric("shiftleft", function(y, x) { standardGeneric("shiftleft") })
#' @rdname column_math_functions
#' @name NULL
setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") })
#' @rdname column_math_functions
#' @name NULL
setGeneric("shiftright", function(y, x) { standardGeneric("shiftright") })
#' @rdname column_math_functions
#' @name NULL
setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") })
#' @rdname column_math_functions
#' @name NULL
setGeneric("shiftrightunsigned", function(y, x) { standardGeneric("shiftrightunsigned") })
#' @rdname column_collection_functions
#' @name NULL
setGeneric("shuffle", function(x) { standardGeneric("shuffle") })
@ -1388,6 +1408,10 @@ setGeneric("struct", function(x, ...) { standardGeneric("struct") })
#' @name NULL
setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") })
#' @rdname column_aggregate_functions
#' @name NULL
setGeneric("sum_distinct", function(x) { standardGeneric("sum_distinct") })
#' @rdname column_aggregate_functions
#' @name NULL
setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") })


@ -1397,7 +1397,8 @@ test_that("column operators", {
test_that("column functions", {
c <- column("a")
c1 <- abs(c) + acos(c) + approx_count_distinct(c) + ascii(c) + asin(c) + atan(c)
c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c)
c2 <- avg(c) + base64(c) + bin(c) + suppressWarnings(bitwiseNOT(c)) +
bitwise_not(c) + cbrt(c) + ceil(c) + cos(c)
c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c)
c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
@ -1405,7 +1406,8 @@ test_that("column functions", {
c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id()
c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
c10 <- sumDistinct(c) + tan(c) + tanh(c) + degrees(c) + radians(c)
c10 <- suppressWarnings(sumDistinct(c)) + sum_distinct(c) + tan(c) + tanh(c) +
degrees(c) + radians(c)
c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c)
c12 <- variance(c) + xxhash64(c) + ltrim(c, "a") + rtrim(c, "b") + trim(c, "c")
c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1)
@ -1457,6 +1459,8 @@ test_that("column functions", {
expect_equal(collect(df3)[[2, 1]], FALSE)
expect_equal(collect(df3)[[3, 1]], TRUE)
df4 <- select(df, count_distinct(df$age, df$name))
expect_equal(collect(df4)[[1, 1]], 2)
df4 <- select(df, countDistinct(df$age, df$name))
expect_equal(collect(df4)[[1, 1]], 2)
@ -1887,9 +1891,12 @@ test_that("column binary mathfunctions", {
expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2))
expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2))
## nolint end
expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16)
expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4)
expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4)
expect_equal(collect(select(df, shiftleft(df$b, 1)))[4, 1], 16)
expect_equal(collect(select(df, shiftright(df$b, 1)))[4, 1], 4)
expect_equal(collect(select(df, shiftrightunsigned(df$b, 1)))[4, 1], 4)
expect_equal(collect(select(df, suppressWarnings(shiftLeft(df$b, 1))))[4, 1], 16)
expect_equal(collect(select(df, suppressWarnings(shiftRight(df$b, 1))))[4, 1], 4)
expect_equal(collect(select(df, suppressWarnings(shiftRightUnsigned(df$b, 1))))[4, 1], 4)
expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric")
expect_equal(collect(select(df, rand(1)))[1, 1], 0.636, tolerance = 0.01)
expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric")


@ -331,7 +331,7 @@ A common flow of grouping and aggregation is
2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group.
A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there.
A number of widely used functions are supported to aggregate data after grouping, including `avg`, `count_distinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sum_distinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there.
For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below.
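
The vignette's own R example is elided from this hunk; below is a rough Scala sketch of the same groupBy-then-agg flow using the renamed aggregate (hypothetical `mtcars`-like `df`):

```scala
// 1. groupBy yields a grouped handle (GroupedData in R, RelationalGroupedDataset in Scala);
// 2. agg computes one value per group using aggregate functions such as count_distinct.
df.groupBy($"cyl")
  .agg(count($"cyl").as("n"), count_distinct($"gear").as("gears"))
  .show()
```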


@ -352,7 +352,7 @@ Scalar functions are functions that return a single value per row, as opposed to
## Aggregate Functions
Aggregate functions are functions that return a single value on a group of rows. The [Built-in Aggregation Functions](sql-ref-functions-builtin.html#aggregate-functions) provide common aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc.
Aggregate functions are functions that return a single value on a group of rows. The [Built-in Aggregation Functions](sql-ref-functions-builtin.html#aggregate-functions) provide common aggregations such as `count()`, `count_distinct()`, `avg()`, `max()`, `min()`, etc.
Users are not limited to the predefined aggregate functions and can create their own. For more details
about user defined aggregate functions, please refer to the documentation of
[User Defined Aggregate Functions](sql-ref-functions-udf-aggregate.html).


@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// col("...") is preferable to df.col("...")
import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.call_udf;
import static org.apache.spark.sql.functions.col;
// $example off$
@ -73,12 +73,12 @@ public class JavaTokenizerExample {
Dataset<Row> tokenized = tokenizer.transform(sentenceDataFrame);
tokenized.select("sentence", "words")
.withColumn("tokens", callUDF("countTokens", col("words")))
.withColumn("tokens", call_udf("countTokens", col("words")))
.show(false);
Dataset<Row> regexTokenized = regexTokenizer.transform(sentenceDataFrame);
regexTokenized.select("sentence", "words")
.withColumn("tokens", callUDF("countTokens", col("words")))
.withColumn("tokens", call_udf("countTokens", col("words")))
.show(false);
// $example off$


@ -340,6 +340,7 @@ Functions
avg
base64
bin
bitwise_not
bitwiseNOT
broadcast
bround
@ -358,6 +359,7 @@ Functions
cos
cosh
count
count_distinct
countDistinct
covar_pop
covar_samp
@ -482,9 +484,9 @@ Functions
sequence
sha1
sha2
shiftLeft
shiftRight
shiftRightUnsigned
shiftleft
shiftright
shiftrightunsigned
shuffle
signum
sin
@ -504,6 +506,7 @@ Functions
substring
substring_index
sum
sum_distinct
sumDistinct
tan
tanh


@ -206,8 +206,20 @@ def mean(col):
def sumDistinct(col):
"""
Aggregate function: returns the sum of distinct values in the expression.
.. deprecated:: 3.2.0
Use :func:`sum_distinct` instead.
"""
return _invoke_function_over_column("sumDistinct", col)
warnings.warn("Deprecated in 3.2, use sum_distinct instead.", FutureWarning)
return sum_distinct(col)
@since(3.2)
def sum_distinct(col):
"""
Aggregate function: returns the sum of distinct values in the expression.
"""
return _invoke_function_over_column("sum_distinct", col)
def acos(col):
@ -494,8 +506,20 @@ def toRadians(col):
def bitwiseNOT(col):
"""
Computes bitwise not.
.. deprecated:: 3.2.0
Use :func:`bitwise_not` instead.
"""
return _invoke_function_over_column("bitwiseNOT", col)
warnings.warn("Deprecated in 3.2, use bitwise_not instead.", FutureWarning)
return bitwise_not(col)
@since(3.2)
def bitwise_not(col):
"""
Computes bitwise not.
"""
return _invoke_function_over_column("bitwise_not", col)
@since(2.4)
@ -810,7 +834,7 @@ def approx_count_distinct(col, rsd=None):
col : :class:`Column` or str
rsd : float, optional
maximum relative standard deviation allowed (default = 0.05).
For rsd < 0.01, it is more efficient to use :func:`countDistinct`
For rsd < 0.01, it is more efficient to use :func:`count_distinct`
Examples
--------
@ -928,18 +952,29 @@ def covar_samp(col1, col2):
def countDistinct(col, *cols):
"""Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
An alias of :func:`count_distinct`, and it is encouraged to use :func:`count_distinct`
directly.
.. versionadded:: 1.3.0
"""
return count_distinct(col, *cols)
def count_distinct(col, *cols):
"""Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
.. versionadded:: 3.2.0
Examples
--------
>>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
>>> df.agg(count_distinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]
>>> df.agg(countDistinct("age", "name").alias('c')).collect()
>>> df.agg(count_distinct("age", "name").alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
jc = sc._jvm.functions.count_distinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
return Column(jc)
@ -1255,13 +1290,25 @@ def shiftLeft(col, numBits):
.. versionadded:: 1.5.0
.. deprecated:: 3.2.0
Use :func:`shiftleft` instead.
"""
warnings.warn("Deprecated in 3.2, use shiftleft instead.", FutureWarning)
return shiftleft(col, numBits)
def shiftleft(col, numBits):
"""Shift the given value numBits left.
.. versionadded:: 3.2.0
Examples
--------
>>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
>>> spark.createDataFrame([(21,)], ['a']).select(shiftleft('a', 1).alias('r')).collect()
[Row(r=42)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.shiftLeft(_to_java_column(col), numBits))
return Column(sc._jvm.functions.shiftleft(_to_java_column(col), numBits))
def shiftRight(col, numBits):
@ -1269,9 +1316,21 @@ def shiftRight(col, numBits):
.. versionadded:: 1.5.0
.. deprecated:: 3.2.0
Use :func:`shiftright` instead.
"""
warnings.warn("Deprecated in 3.2, use shiftright instead.", FutureWarning)
return shiftright(col, numBits)
def shiftright(col, numBits):
"""(Signed) shift the given value numBits right.
.. versionadded:: 3.2.0
Examples
--------
>>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
>>> spark.createDataFrame([(42,)], ['a']).select(shiftright('a', 1).alias('r')).collect()
[Row(r=21)]
"""
sc = SparkContext._active_spark_context
@ -1284,10 +1343,22 @@ def shiftRightUnsigned(col, numBits):
.. versionadded:: 1.5.0
.. deprecated:: 3.2.0
Use :func:`shiftrightunsigned` instead.
"""
warnings.warn("Deprecated in 3.2, use shiftrightunsigned instead.", FutureWarning)
return shiftrightunsigned(col, numBits)
def shiftrightunsigned(col, numBits):
"""Unsigned shift the given value numBits right.
.. versionadded:: 3.2.0
Examples
--------
>>> df = spark.createDataFrame([(-42,)], ['a'])
>>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect()
>>> df.select(shiftrightunsigned('a', 1).alias('r')).collect()
[Row(r=9223372036854775787)]
"""
sc = SparkContext._active_spark_context
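
As a quick sanity check of the `shiftrightunsigned` doctest above, a minimal sketch of the underlying two's-complement arithmetic (plain Scala, no Spark needed):

```scala
// -42 as a 64-bit two's-complement pattern is 2^64 - 42 = 18446744073709551574.
println(-42L >>> 1)  // 9223372036854775787: logical (unsigned) right shift, as in the doctest
println(-42L >> 1)   // -21: arithmetic (signed) right shift, what shiftright computes
```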


@ -45,6 +45,7 @@ def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: ...
def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: ...
def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: ...
def first(col: ColumnOrName, ignorenulls: bool = ...) -> Column: ...
def grouping(col: ColumnOrName) -> Column: ...
def grouping_id(*cols: ColumnOrName) -> Column: ...
@ -64,8 +65,11 @@ def randn(seed: Optional[int] = ...) -> Column: ...
def round(col: ColumnOrName, scale: int = ...) -> Column: ...
def bround(col: ColumnOrName, scale: int = ...) -> Column: ...
def shiftLeft(col: ColumnOrName, numBits: int) -> Column: ...
def shiftleft(col: ColumnOrName, numBits: int) -> Column: ...
def shiftRight(col: ColumnOrName, numBits: int) -> Column: ...
def shiftright(col: ColumnOrName, numBits: int) -> Column: ...
def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: ...
def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column: ...
def spark_partition_id() -> Column: ...
def expr(str: str) -> Column: ...
def struct(*cols: ColumnOrName) -> Column: ...
@ -278,6 +282,7 @@ def atan2(col1: ColumnOrName, col2: float) -> Column: ...
def avg(col: ColumnOrName) -> Column: ...
def base64(col: ColumnOrName) -> Column: ...
def bitwiseNOT(col: ColumnOrName) -> Column: ...
def bitwise_not(col: ColumnOrName) -> Column: ...
def cbrt(col: ColumnOrName) -> Column: ...
def ceil(col: ColumnOrName) -> Column: ...
def col(col: str) -> Column: ...
@ -333,6 +338,7 @@ def stddev_pop(col: ColumnOrName) -> Column: ...
def stddev_samp(col: ColumnOrName) -> Column: ...
def sum(col: ColumnOrName) -> Column: ...
def sumDistinct(col: ColumnOrName) -> Column: ...
def sum_distinct(col: ColumnOrName) -> Column: ...
def tan(col: ColumnOrName) -> Column: ...
def tanh(col: ColumnOrName) -> Column: ...
def toDegrees(col: ColumnOrName) -> Column: ...

View file

@ -139,6 +139,8 @@ class ColumnTests(ReusedSQLTestCase):
self.assertEqual(170 ^ 75, result['(a ^ b)'])
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
self.assertEqual(~75, result['~b'])
result = df.select(functions.bitwise_not(df.b)).collect()[0].asDict()
self.assertEqual(~75, result['~b'])
def test_with_field(self):
from pyspark.sql.functions import lit, col


@ -21,7 +21,9 @@ import re
from py4j.protocol import Py4JJavaError
from pyspark.sql import Row, Window
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, lit
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, \
lit, assert_true, sum_distinct, sumDistinct, shiftleft, shiftLeft, shiftRight, \
shiftright, shiftrightunsigned, shiftRightUnsigned
from pyspark.testing.sqlutils import ReusedSQLTestCase
@ -640,6 +642,23 @@ class FunctionsTests(ReusedSQLTestCase):
str(cm.exception)
)
def test_sum_distinct(self):
self.spark.range(10).select(
assert_true(sum_distinct(col("id")) == sumDistinct(col("id")))).collect()
def test_shiftleft(self):
self.spark.range(10).select(
assert_true(shiftLeft(col("id"), 2) == shiftleft(col("id"), 2))).collect()
def test_shiftright(self):
self.spark.range(10).select(
assert_true(shiftRight(col("id"), 2) == shiftright(col("id"), 2))).collect()
def test_shiftrightunsigned(self):
self.spark.range(10).select(
assert_true(
shiftRightUnsigned(col("id"), 2) == shiftrightunsigned(col("id"), 2))).collect()
if __name__ == "__main__":
import unittest

View file

@ -31,6 +31,7 @@ class GroupTests(ReusedSQLTestCase):
self.assertEqual((0, u'99'),
tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0])
# test deprecated countDistinct
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])


@ -39,8 +39,8 @@ import org.apache.spark.sql.types.IntegerType
*
* val agg = data.groupBy($"key")
* .agg(
* countDistinct($"cat1").as("cat1_cnt"),
* countDistinct($"cat2").as("cat2_cnt"),
* count_distinct($"cat1").as("cat1_cnt"),
* count_distinct($"cat2").as("cat2_cnt"),
* sum($"value").as("total"))
* }}}
*


@ -2319,7 +2319,7 @@ class Dataset[T] private[sql](
*
* val allWords = ds.select('title, explode(split('words, " ")).as("word"))
*
* val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
* val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title"))
* }}}
*
* Using `flatMap()` this can similarly be exploded as:


@ -113,6 +113,16 @@ object functions {
*/
def lit(literal: Any): Column = typedLit(literal)
/**
* Creates a [[Column]] of literal value.
*
* An alias of `typedlit`, and it is encouraged to use `typedlit` directly.
*
* @group normal_funcs
* @since 2.2.0
*/
def typedLit[T : TypeTag](literal: T): Column = typedlit(literal)
/**
* Creates a [[Column]] of literal value.
*
@ -123,9 +133,9 @@ object functions {
* can handle parameterized scala types e.g.: List, Seq and Map.
*
* @group normal_funcs
* @since 2.2.0
* @since 3.2.0
*/
def typedLit[T : TypeTag](literal: T): Column = literal match {
def typedlit[T : TypeTag](literal: T): Column = literal match {
case c: Column => c
case s: Symbol => new ColumnName(s.name)
case _ => Column(Literal.create(literal))
@ -388,24 +398,37 @@ object functions {
/**
* Aggregate function: returns the number of distinct items in a group.
*
* An alias of `count_distinct`, and it is encouraged to use `count_distinct` directly.
*
* @group agg_funcs
* @since 1.3.0
*/
@scala.annotation.varargs
def countDistinct(expr: Column, exprs: Column*): Column =
// For usage like countDistinct("*"), we should let analyzer expand star and
// resolve function.
Column(UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true))
def countDistinct(expr: Column, exprs: Column*): Column = count_distinct(expr, exprs: _*)
/**
* Aggregate function: returns the number of distinct items in a group.
*
* An alias of `count_distinct`, and it is encouraged to use `count_distinct` directly.
*
* @group agg_funcs
* @since 1.3.0
*/
@scala.annotation.varargs
def countDistinct(columnName: String, columnNames: String*): Column =
countDistinct(Column(columnName), columnNames.map(Column.apply) : _*)
count_distinct(Column(columnName), columnNames.map(Column.apply) : _*)
/**
* Aggregate function: returns the number of distinct items in a group.
*
* @group agg_funcs
* @since 3.2.0
*/
@scala.annotation.varargs
def count_distinct(expr: Column, exprs: Column*): Column =
// For usage like countDistinct("*"), we should let analyzer expand star and
// resolve function.
Column(UnresolvedFunction("count", (expr +: exprs).map(_.expr), isDistinct = true))
/**
* Aggregate function: returns the population covariance for two columns.
@ -796,6 +819,7 @@ object functions {
* @group agg_funcs
* @since 1.3.0
*/
@deprecated("Use sum_distinct", "3.2.0")
def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true)
/**
@ -804,7 +828,16 @@ object functions {
* @group agg_funcs
* @since 1.3.0
*/
def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
@deprecated("Use sum_distinct", "3.2.0")
def sumDistinct(columnName: String): Column = sum_distinct(Column(columnName))
/**
* Aggregate function: returns the sum of distinct values in the expression.
*
* @group agg_funcs
* @since 3.2.0
*/
def sum_distinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true)
/**
* Aggregate function: alias for `var_samp`.
@ -1411,7 +1444,16 @@ object functions {
* @group normal_funcs
* @since 1.4.0
*/
def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) }
@deprecated("Use bitwise_not", "3.2.0")
def bitwiseNOT(e: Column): Column = bitwise_not(e)
/**
* Computes bitwise NOT (~) of a number.
*
* @group normal_funcs
* @since 3.2.0
*/
def bitwise_not(e: Column): Column = withExpr { BitwiseNot(e.expr) }
/**
* Parses the expression string into the column that it represents, similar to
@ -2142,7 +2184,17 @@ object functions {
* @group math_funcs
* @since 1.5.0
*/
def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) }
@deprecated("Use shiftleft", "3.2.0")
def shiftLeft(e: Column, numBits: Int): Column = shiftleft(e, numBits)
/**
* Shift the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
*
* @group math_funcs
* @since 3.2.0
*/
def shiftleft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) }
/**
* (Signed) shift the given value numBits right. If the given value is a long value, it will
@ -2151,7 +2203,17 @@ object functions {
* @group math_funcs
* @since 1.5.0
*/
def shiftRight(e: Column, numBits: Int): Column = withExpr {
@deprecated("Use shiftright", "3.2.0")
def shiftRight(e: Column, numBits: Int): Column = shiftright(e, numBits)
/**
* (Signed) shift the given value numBits right. If the given value is a long value, it will
* return a long value else it will return an integer value.
*
* @group math_funcs
* @since 3.2.0
*/
def shiftright(e: Column, numBits: Int): Column = withExpr {
ShiftRight(e.expr, lit(numBits).expr)
}
@ -2162,7 +2224,17 @@ object functions {
* @group math_funcs
* @since 1.5.0
*/
def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr {
@deprecated("Use shiftrightunsigned", "3.2.0")
def shiftRightUnsigned(e: Column, numBits: Int): Column = shiftrightunsigned(e, numBits)
/**
* Unsigned shift the given value numBits right. If the given value is a long value,
* it will return a long value else it will return an integer value.
*
* @group math_funcs
* @since 3.2.0
*/
def shiftrightunsigned(e: Column, numBits: Int): Column = withExpr {
ShiftRightUnsigned(e.expr, lit(numBits).expr)
}
@ -5073,6 +5145,17 @@ object functions {
SparkUserDefinedFunction(f, dataType, inputEncoders = Nil)
}
/**
* Call an user-defined function.
*
* @group udf_funcs
* @since 1.5.0
*/
@scala.annotation.varargs
@deprecated("Use call_udf")
def callUDF(udfName: String, cols: Column*): Column =
call_udf(udfName, cols: _*)
/**
* Call an user-defined function.
* Example:
@ -5082,14 +5165,14 @@ object functions {
* val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
* val spark = df.sparkSession
* spark.udf.register("simpleUDF", (v: Int) => v * v)
* df.select($"id", callUDF("simpleUDF", $"value"))
* df.select($"id", call_udf("simpleUDF", $"value"))
* }}}
*
* @group udf_funcs
* @since 1.5.0
* @since 3.2.0
*/
@scala.annotation.varargs
def callUDF(udfName: String, cols: Column*): Column = withExpr {
def call_udf(udfName: String, cols: Column*): Column = withExpr {
UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
}
}
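
Mirroring the scaladoc above, a hedged end-to-end sketch of the renamed UDF entry point (assumes a running `SparkSession` named `spark`):

```scala
import spark.implicits._
import org.apache.spark.sql.functions.call_udf

val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
spark.udf.register("simpleUDF", (v: Int) => v * v)

df.select($"id", call_udf("simpleUDF", $"value")).show()  // new name, since 3.2.0
// callUDF("simpleUDF", $"value") still compiles but now emits a deprecation warning.
```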


@ -105,7 +105,7 @@ public class JavaDataFrameSuite {
// Varargs in column expressions
df.groupBy().agg(countDistinct("key", "value"));
df.groupBy().agg(countDistinct(col("key"), col("value")));
df.groupBy().agg(count_distinct(col("key"), col("value")));
df.select(coalesce(col("key")));
// Varargs with mathfunctions


@ -297,7 +297,7 @@ class DataFrameAggregateSuite extends QueryTest
Row(2.0, 2.0))
checkAnswer(
testData2.agg(avg($"a"), sumDistinct($"a")), // non-partial
testData2.agg(avg($"a"), sumDistinct($"a")), // non-partial and test deprecated version
Row(2.0, 6.0) :: Nil)
checkAnswer(
@ -305,7 +305,7 @@ class DataFrameAggregateSuite extends QueryTest
Row(new java.math.BigDecimal(2)))
checkAnswer(
decimalData.agg(avg($"a"), sumDistinct($"a")), // non-partial
decimalData.agg(avg($"a"), sum_distinct($"a")), // non-partial
Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
@ -314,7 +314,7 @@ class DataFrameAggregateSuite extends QueryTest
// non-partial
checkAnswer(
decimalData.agg(
avg($"a" cast DecimalType(10, 2)), sumDistinct($"a" cast DecimalType(10, 2))),
avg($"a" cast DecimalType(10, 2)), sum_distinct($"a" cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil)
}
@ -324,11 +324,11 @@ class DataFrameAggregateSuite extends QueryTest
Row(2.0))
checkAnswer(
testData3.agg(avg($"b"), countDistinct($"b")),
testData3.agg(avg($"b"), count_distinct($"b")),
Row(2.0, 1))
checkAnswer(
testData3.agg(avg($"b"), sumDistinct($"b")), // non-partial
testData3.agg(avg($"b"), sum_distinct($"b")), // non-partial
Row(2.0, 2.0))
}
@ -339,7 +339,7 @@ class DataFrameAggregateSuite extends QueryTest
Row(null))
checkAnswer(
emptyTableData.agg(avg($"a"), sumDistinct($"b")), // non-partial
emptyTableData.agg(avg($"a"), sum_distinct($"b")), // non-partial
Row(null, null))
}
@ -347,7 +347,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(testData2.count() === testData2.rdd.map(_ => 1).count())
checkAnswer(
testData2.agg(count($"a"), sumDistinct($"a")), // non-partial
testData2.agg(count($"a"), sum_distinct($"a")), // non-partial
Row(6, 6.0))
}
@ -364,12 +364,12 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(
testData3.agg(
count($"a"), count($"b"), count(lit(1)), countDistinct($"a"), countDistinct($"b")),
count($"a"), count($"b"), count(lit(1)), count_distinct($"a"), count_distinct($"b")),
Row(2, 1, 2, 2, 1)
)
checkAnswer(
testData3.agg(count($"b"), countDistinct($"b"), sumDistinct($"b")), // non-partial
testData3.agg(count($"b"), count_distinct($"b"), sum_distinct($"b")), // non-partial
Row(1, 1, 2)
)
}
@ -384,17 +384,17 @@ class DataFrameAggregateSuite extends QueryTest
.toDF("key1", "key2", "key3")
checkAnswer(
df1.agg(countDistinct($"key1", $"key2")),
df1.agg(count_distinct($"key1", $"key2")),
Row(3)
)
checkAnswer(
df1.agg(countDistinct($"key1", $"key2", $"key3")),
df1.agg(count_distinct($"key1", $"key2", $"key3")),
Row(3)
)
checkAnswer(
df1.groupBy($"key1").agg(countDistinct($"key2", $"key3")),
df1.groupBy($"key1").agg(count_distinct($"key2", $"key3")),
Seq(Row("a", 2), Row("x", 1))
)
}
@ -402,7 +402,7 @@ class DataFrameAggregateSuite extends QueryTest
test("zero count") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(count($"a"), sumDistinct($"a")), // non-partial
emptyTableData.agg(count($"a"), sum_distinct($"a")), // non-partial
Row(0, null))
}
@ -433,7 +433,7 @@ class DataFrameAggregateSuite extends QueryTest
test("zero sum distinct") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(sumDistinct($"a")),
emptyTableData.agg(sum_distinct($"a")),
Row(null))
}
@ -622,7 +622,7 @@ class DataFrameAggregateSuite extends QueryTest
val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d"))
.toDF("x", "y", "z")
checkAnswer(
df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))),
df.groupBy($"x").agg(count_distinct($"y"), sort_array(collect_list($"z"))),
Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
}
@ -837,7 +837,7 @@ class DataFrameAggregateSuite extends QueryTest
)
}
test("SPARK-27581: DataFrame countDistinct(\"*\") shouldn't fail with AnalysisException") {
test("SPARK-27581: DataFrame count_distinct(\"*\") shouldn't fail with AnalysisException") {
val df = sql("select id % 100 from range(100000)")
val distinctCount1 = df.select(expr("count(distinct(*))"))
val distinctCount2 = df.select(countDistinct("*"))


@ -171,10 +171,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
)
}
test("bitwiseNOT") {
test("bitwise_not") {
checkAnswer(
testData2.select(bitwiseNOT($"a")),
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
testData2.select(bitwiseNOT($"a"), bitwise_not($"a")),
testData2.collect().toSeq.map(r => Row(~r.getInt(0), ~r.getInt(0))))
}
test("bin") {


@ -133,7 +133,7 @@ class DataFrameSuite extends QueryTest
df2
.select('_1 as 'letter, 'number)
.groupBy('letter)
.agg(countDistinct('number)),
.agg(count_distinct('number)),
Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
)
}
@ -513,7 +513,7 @@ class DataFrameSuite extends QueryTest
Row(5, false)))
checkAnswer(
testData2.select(sumDistinct($"a")),
testData2.select(sum_distinct($"a")),
Row(6))
}
@ -607,7 +607,7 @@ class DataFrameSuite extends QueryTest
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)
checkAnswer(
df.select($"id", callUDF("simpleUDF", $"value")),
df.select($"id", callUDF("simpleUDF", $"value")), // test deprecated one
Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
}


@ -47,7 +47,7 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession {
test("function current_timestamp and now") {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
checkAnswer(df1.select(count_distinct(current_timestamp())), Row(1))
// Execution in one query should return the same value
checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true))


@ -366,14 +366,14 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(
df.select(
shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
shiftLeft('f, 1)),
shiftleft('a, 1), shiftleft('b, 1), shiftleft('c, 1), shiftleft('d, 1),
shiftLeft('f, 1)), // test deprecated one.
Row(42.toLong, 42, 42.toShort, 42.toByte, null))
checkAnswer(
df.selectExpr(
"shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)",
"shiftLeft(f, 1)"),
"shiftleft(a, 1)", "shiftleft(b, 1)", "shiftleft(b, 1)", "shiftleft(d, 1)",
"shiftleft(f, 1)"),
Row(42.toLong, 42, 42.toShort, 42.toByte, null))
}
@ -383,14 +383,14 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(
df.select(
shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
shiftRight('f, 1)),
shiftright('a, 1), shiftright('b, 1), shiftright('c, 1), shiftright('d, 1),
shiftRight('f, 1)), // test deprecated one.
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
checkAnswer(
df.selectExpr(
"shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
"shiftRight(f, 1)"),
"shiftright(a, 1)", "shiftright(b, 1)", "shiftright(c, 1)", "shiftright(d, 1)",
"shiftright(f, 1)"),
Row(21.toLong, 21, 21.toShort, 21.toByte, null))
}
@ -400,14 +400,14 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(
df.select(
shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)),
shiftrightunsigned('a, 1), shiftrightunsigned('b, 1), shiftrightunsigned('c, 1),
shiftrightunsigned('d, 1), shiftRightUnsigned('f, 1)), // test deprecated one.
Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
checkAnswer(
df.selectExpr(
"shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
"shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"),
"shiftrightunsigned(a, 1)", "shiftrightunsigned(b, 1)", "shiftrightunsigned(c, 1)",
"shiftrightunsigned(d, 1)", "shiftrightunsigned(f, 1)"),
Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null))
}


@ -61,13 +61,13 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
}
test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
val query = testData.groupBy('value).agg(count_distinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
}
test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
testData.groupBy('value).agg(count('value), count_distinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
}


@ -120,8 +120,8 @@ class SameResultSuite extends QueryTest with SharedSparkSession {
val df2 = spark.range(10).agg(sum($"id"))
assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
val df3 = spark.range(10).agg(sumDistinct($"id"))
val df4 = spark.range(10).agg(sumDistinct($"id"))
val df3 = spark.range(10).agg(sum_distinct($"id"))
val df4 = spark.range(10).agg(sum_distinct($"id"))
assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
}


@ -85,7 +85,7 @@ public class JavaDataFrameSuite {
udaf.distinct(col("value")),
udaf.apply(col("value")),
registeredUDAF.apply(col("value")),
callUDF("mydoublesum", col("value")));
callUDF("mydoublesum", col("value"))); // test deprecated one
List<Row> expectedResult = new ArrayList<>();
expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));


@ -711,7 +711,7 @@ object SPARK_9757 extends QueryTest {
val df =
hiveContext
.range(10)
.select(callUDF("struct", ($"id" + 0.2) cast DecimalType(10, 3)) as "dec_struct")
.select(call_udf("struct", ($"id" + 0.2) cast DecimalType(10, 3)) as "dec_struct")
df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
checkAnswer(hiveContext.table("t"), df)
}


@ -219,7 +219,7 @@ class ObjectHashAggregateSuite
val withPartialSafe = max($"c2")
// A Spark SQL native distinct aggregate function
val withDistinct = countDistinct($"c3")
val withDistinct = count_distinct($"c3")
val allAggs = Seq(
"typed" -> typed,