diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 91f169e7ea..f1a333b8e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -314,8 +314,8 @@ abstract class OffsetWindowFunction val offset: Expression /** - * Direction (above = 1/below = -1) of the number of rows between the current row and the row - * where the input expression is evaluated. + * Direction of the number of rows between the current row and the row where the input expression + * is evaluated. */ val direction: SortDirection @@ -327,7 +327,7 @@ abstract class OffsetWindowFunction * both the input and the default expression are foldable, the result is still not foldable due to * the frame. */ - override def foldable: Boolean = input.foldable && (default == null || default.foldable) + override def foldable: Boolean = false override def nullable: Boolean = default == null || default.nullable @@ -353,6 +353,21 @@ abstract class OffsetWindowFunction override def toString: String = s"$prettyName($input, $offset, $default)" } +/** + * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window. + * Offsets start at 0, which is the current row. The offset must be constant integer value. The + * default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger + * than the window, the default expression is evaluated. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param input expression to evaluate 'offset' rows after the current row. + * @param offset rows to jump ahead in the partition. + * @param default to use when the input value is null or when the offset is larger than the window. + */ +@ExpressionDescription(usage = + """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows after the + current row in the window""") case class Lead(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -365,6 +380,21 @@ case class Lead(input: Expression, offset: Expression, default: Expression) override val direction = Ascending } +/** + * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window. + * Offsets start at 0, which is the current row. The offset must be constant integer value. The + * default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller + * than the window, the default expression is evaluated. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param input expression to evaluate 'offset' rows before the current row. + * @param offset rows to jump back in the partition. + * @param default to use when the input value is null or when the offset is smaller than the window. + */ +@ExpressionDescription(usage = + """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows before the + current row in the window""") case class Lag(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -409,10 +439,31 @@ object SizeBasedWindowFunction { val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() } +/** + * The RowNumber function computes a unique, sequential number to each row, starting with one, + * according to the ordering of rows within the window partition. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + */ +@ExpressionDescription(usage = + """_FUNC_() - The ROW_NUMBER() function assigns a unique, sequential + number to each row, starting with one, according to the ordering of rows within the window + partition.""") case class RowNumber() extends RowNumberLike { override val evaluateExpression = rowNumber } +/** + * The CumeDist function computes the position of a value relative to a all values in the partition. + * The result is the number of rows preceding or equal to the current row in the ordering of the + * partition divided by the total number of rows in the window partition. Any tie values in the + * ordering will evaluate to the same position. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + */ +@ExpressionDescription(usage = + """_FUNC_() - The CUME_DIST() function computes the position of a value relative to a all values + in the partition.""") case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { override def dataType: DataType = DoubleType // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must @@ -421,6 +472,30 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) } +/** + * The NTile function divides the rows for each window partition into 'n' buckets ranging from 1 to + * at most 'n'. Bucket values will differ by at most 1. If the number of rows in the partition does + * not divide evenly into the number of buckets, then the remainder values are distributed one per + * bucket, starting with the first bucket. + * + * The NTile function is particularly useful for the calculation of tertiles, quartiles, deciles and + * other common summary statistics + * + * The function calculates two variables during initialization: The size of a regular bucket, and + * the number of buckets that will have one extra row added to it (when the rows do not evenly fit + * into the number of buckets); both variables are based on the size of the current partition. + * During the calculation process the function keeps track of the current row number, the current + * bucket number, and the row number at which the bucket will change (bucketThreshold). When the + * current row number reaches bucket threshold, the bucket value is increased by one and the the + * threshold is increased by the bucket size (plus one extra if the current bucket is padded). + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param buckets number of buckets to divide the rows in. Default value is 1. + */ +@ExpressionDescription(usage = + """_FUNC_(x) - The NTILE(n) function divides the rows for each window partition into 'n' buckets + ranging from 1 to at most 'n'.""") case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { def this() = this(Literal(1)) @@ -474,6 +549,8 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow * the order of the window in which is processed. For instance, when the value of 'x' changes in a * window ordered by 'x' the rank function also changes. The size of the change of the rank function * is (typically) not dependent on the size of the change in 'x'. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. */ abstract class RankLike extends AggregateWindowFunction { override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) @@ -513,11 +590,41 @@ abstract class RankLike extends AggregateWindowFunction { def withOrder(order: Seq[Expression]): RankLike } +/** + * The Rank function computes the rank of a value in a group of values. The result is one plus the + * number of rows preceding or equal to the current row in the ordering of the partition. Tie values + * will produce gaps in the sequence. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - RANK() computes the rank of a value in a group of values. The result is one plus + the number of rows preceding or equal to the current row in the ordering of the partition. Tie + values will produce gaps in the sequence.""") case class Rank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): Rank = Rank(order) } +/** + * The DenseRank function computes the rank of a value in a group of values. The result is one plus + * the previously assigned rank value. Unlike Rank, DenseRank will not produce gaps in the ranking + * sequence. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - The DENSE_RANK() function computes the rank of a value in a group of values. The + result is one plus the previously assigned rank value. Unlike Rank, DenseRank will not produce + gaps in the ranking sequence.""") case class DenseRank(children: Seq[Expression]) extends RankLike { def this() = this(Nil) override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) @@ -527,6 +634,23 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { override val initialValues = zero +: orderInit } +/** + * The PercentRank function computes the percentage ranking of a value in a group of values. The + * result the rank of the minus one divided by the total number of rows in the partitiion minus one: + * (r - 1) / (n - 1). If a partition only contains one row, the function will return 0. + * + * The PercentRank function is similar to the CumeDist function, but it uses rank values instead of + * row counts in the its numerator. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - PERCENT_RANK() The PercentRank function computes the percentage ranking of a value + in a group of values.""") case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction { def this() = this(Nil) override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index b50d7604e0..3917b9762b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -292,4 +292,24 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row("b", 3, 8, 32), Row("b", 2, 4, 8))) } + + test("null inputs") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)) + .toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.select( + $"key", + $"value", + avg(lit(null)).over(window), + sum(lit(null)).over(window)), + Seq( + Row("a", 1, null, null), + Row("a", 1, null, null), + Row("a", 2, null, null), + Row("a", 2, null, null), + Row("b", 4, null, null), + Row("b", 3, null, null), + Row("b", 2, null, null))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index c05dbfd760..ea82b8c459 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -227,4 +227,19 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) // scalastyle:on } + + test("null arguments") { + checkAnswer(sql(""" + |select p_mfgr, p_name, p_size, + |sum(null) over(distribute by p_mfgr sort by p_name) as sum, + |avg(null) over(distribute by p_mfgr sort by p_name) as avg + |from part + """.stripMargin), + sql(""" + |select p_mfgr, p_name, p_size, + |null as sum, + |null as avg + |from part + """.stripMargin)) + } }