[SPARK-33443][SQL] LEAD/LAG should support [ IGNORE NULLS | RESPECT NULLS ]

### What changes were proposed in this pull request?
Mainstream databases support `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG`/`NTH_VALUE`/`FIRST_VALUE`/`LAST_VALUE`.
But the current implementation of `LEAD`/`LAG` doesn't support this syntax.

**Oracle**
https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/LEAD.html#GUID-0A0481F1-E98F-4535-A739-FCCA8D1B5B77

**Presto**
https://prestodb.io/docs/current/functions/window.html

**Redshift**
https://docs.aws.amazon.com/redshift/latest/dg/r_WF_LEAD.html

**DB2**
https://www.ibm.com/support/knowledgecenter/SSGU8G_14.1.0/com.ibm.sqls.doc/ids_sqs_1513.htm

**Teradata**
https://docs.teradata.com/r/756LNiPSFdY~4JcCCcR5Cw/GjCT6l7trjkIEjt~7Dhx4w

**Snowflake**
https://docs.snowflake.com/en/sql-reference/functions/lead.html
https://docs.snowflake.com/en/sql-reference/functions/lag.html

### Why are the changes needed?
Supporting `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG` is very useful.

### Does this PR introduce _any_ user-facing change?
'Yes'.

### How was this patch tested?
Jenkins test.

Closes #30387 from beliefer/SPARK-33443.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: beliefer <beliefer@163.com>
Co-authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
gengjiaan 2020-12-24 08:13:48 +00:00 committed by Wenchen Fan
parent 32d4a2b062
commit 3e9821edfd
5 changed files with 238 additions and 34 deletions

View file

@ -387,8 +387,6 @@ abstract class FrameLessOffsetWindowFunction
override def nullable: Boolean = default == null || default.nullable || input.nullable
override val ignoreNulls = false
override lazy val frame: WindowFrame = fakeFrame
override def checkInputDataTypes(): TypeCheckResult = {
@ -443,9 +441,13 @@ abstract class FrameLessOffsetWindowFunction
since = "2.0.0",
group = "window_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class Lead(input: Expression, offset: Expression, default: Expression)
case class Lead(
input: Expression, offset: Expression, default: Expression, ignoreNulls: Boolean)
extends FrameLessOffsetWindowFunction {
def this(input: Expression, offset: Expression, default: Expression) =
this(input, offset, default, false)
def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
def this(input: Expression) = this(input, Literal(1))
@ -485,10 +487,14 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
since = "2.0.0",
group = "window_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class Lag(input: Expression, inputOffset: Expression, default: Expression)
case class Lag(
input: Expression, inputOffset: Expression, default: Expression, ignoreNulls: Boolean)
extends FrameLessOffsetWindowFunction {
def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
def this(input: Expression, inputOffset: Expression, default: Expression) =
this(input, inputOffset, default, false)
def this(input: Expression, inputOffset: Expression) = this(input, inputOffset, Literal(null))
def this(input: Expression) = this(input, Literal(1))

View file

@ -119,13 +119,21 @@ trait WindowExecBase extends UnaryExecNode {
* [[WindowExpression]]s and factory function for the [[WindowFrameFunction]].
*/
protected lazy val windowFrameExpressionFactoryPairs = {
type FrameKey = (String, FrameType, Expression, Expression)
type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression])
type ExpressionBuffer = mutable.Buffer[Expression]
val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
// Add an expression and its function to the map for a given frame.
def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
val key = (tpe, fr.frameType, fr.lower, fr.upper)
val key = fn match {
// This branch is used for Lead/Lag to support ignoring null.
// All window frames move in rows. If there are multiple Leads or Lags acting on a row
// and operating on different input expressions, they should not be moved uniformly
// by row. Therefore, we put these functions in different window frames.
case f: FrameLessOffsetWindowFunction if f.ignoreNulls =>
(tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized))
case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil)
}
val (es, fns) = framedFunctions.getOrElseUpdate(
key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
es += e
@ -183,7 +191,7 @@ trait WindowExecBase extends UnaryExecNode {
// Create the factory to produce WindowFunctionFrame.
val factory = key match {
// Frameless offset Frame
case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _) =>
case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) =>
target: InternalRow =>
new FrameLessOffsetWindowFunctionFrame(
target,
@ -193,8 +201,9 @@ trait WindowExecBase extends UnaryExecNode {
child.output,
(expressions, schema) =>
MutableProjection.create(expressions, schema),
offset)
case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _) =>
offset,
expr.nonEmpty)
case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, _) =>
target: InternalRow => {
new UnboundedOffsetWindowFunctionFrame(
target,
@ -206,7 +215,7 @@ trait WindowExecBase extends UnaryExecNode {
MutableProjection.create(expressions, schema),
offset)
}
case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _) =>
case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, _) =>
target: InternalRow => {
new UnboundedPrecedingOffsetWindowFunctionFrame(
target,
@ -220,13 +229,13 @@ trait WindowExecBase extends UnaryExecNode {
}
// Entire Partition Frame.
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) =>
target: InternalRow => {
new UnboundedWindowFunctionFrame(target, processor)
}
// Growing Frame.
case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) =>
target: InternalRow => {
new UnboundedPrecedingWindowFunctionFrame(
target,
@ -235,7 +244,7 @@ trait WindowExecBase extends UnaryExecNode {
}
// Shrinking Frame.
case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) =>
target: InternalRow => {
new UnboundedFollowingWindowFunctionFrame(
target,
@ -244,7 +253,7 @@ trait WindowExecBase extends UnaryExecNode {
}
// Moving Frame.
case ("AGGREGATE", frameType, lower, upper) =>
case ("AGGREGATE", frameType, lower, upper, _) =>
target: InternalRow => {
new SlidingWindowFunctionFrame(
target,

View file

@ -97,13 +97,15 @@ abstract class OffsetWindowFunctionFrameBase(
/** Index of the input row currently used for output. */
protected var inputIndex = 0
/** Attributes of the input row currently used for output. */
protected val inputAttrs = inputSchema.map(_.withNullability(true))
/**
* Create the projection used when the offset row exists.
* Please note that this projection always respects null input values (like PostgreSQL).
*/
protected val projection = {
// Collect the expressions and bind them.
val inputAttrs = inputSchema.map(_.withNullability(true))
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
expressions.toSeq.map(_.input), inputAttrs)
@ -114,7 +116,6 @@ abstract class OffsetWindowFunctionFrameBase(
/** Create the projection used when the offset row DOES NOT exist. */
protected val fillDefaultValue = {
// Collect the expressions and bind them.
val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
if (e.default == null || e.default.foldable && e.default.eval() == null) {
// The default value is null.
@ -147,31 +148,132 @@ class FrameLessOffsetWindowFunctionFrame(
expressions: Array[OffsetWindowFunction],
inputSchema: Seq[Attribute],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
offset: Int)
offset: Int,
ignoreNulls: Boolean = false)
extends OffsetWindowFunctionFrameBase(
target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
/**
 * Holds the most recently selected row whose input evaluates to non-null, or `EmptyRow`
 * when no such row is currently selected.
 * NOTE(review): the visible `prepare` does not reset this field or `skippedNonNullCount`;
 * confirm that state cannot leak from one partition into the next.
 */
private var nextSelectedRow = EmptyRow
// Number of rows with a non-null input counted so far towards `offset`. It is incremented
// and decremented by the row-search helper and by the write functions of this class.
private var skippedNonNullCount = 0
/**
 * Projection that evaluates `IsNull` on the input of the first expression only
 * (`expressions.head`), bound against `inputSchema`.
 */
private val project = UnsafeProjection.create(Seq(IsNull(expressions.head.input)), inputSchema)
/** Returns true when the input of the first window expression evaluates to null for `row`. */
private def nullCheck(row: InternalRow): Boolean = {
  val isNullResult = project(row)
  isNullResult.getBoolean(0)
}
/**
 * Advances the input iterator until `offset` rows with a non-null input have been counted
 * or the input is exhausted, remembering the last such row in `nextSelectedRow`.
 */
private def findNextRowWithNonNullInput(): Unit = {
  while (skippedNonNullCount < offset && inputIndex < input.length) {
    val candidate = WindowFunctionFrame.getNextOrNull(inputIterator)
    val hasNonNullInput = !nullCheck(candidate)
    if (hasNonNullInput) {
      skippedNonNullCount += 1
      nextSelectedRow = candidate
    }
    // Every consumed row advances the index, whether or not it was selected.
    inputIndex += 1
  }
}
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than zero
inputIndex = 0
while (inputIndex < offset) {
if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
if (ignoreNulls) {
findNextRowWithNonNullInput()
} else {
while (inputIndex < offset) {
if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
}
inputIndex = offset
}
inputIndex = offset
}
/**
 * The per-row write function, chosen once when the frame is constructed:
 *  - `ignoreNulls` with a positive offset: look ahead, skipping rows whose input is null;
 *  - `ignoreNulls` with a negative offset: look behind, skipping rows whose input is null;
 *  - otherwise: use the row at `inputIndex` regardless of null input values.
 * Each variant either projects the selected row or fills in the default value.
 */
private val doWrite = if (ignoreNulls && offset > 0) {
// For illustration, here is one example: the input data contains nine rows,
// and the input values of each row are: null, x, null, null, y, null, z, v, null.
// We use lead(input, 2) with IGNORE NULLS and the process is as follows:
// 1. current row -> null, next selected row -> y, output: y;
// 2. current row -> x, next selected row -> z, output: z;
// 3. current row -> null, next selected row -> z, output: z;
// 4. current row -> null, next selected row -> z, output: z;
// 5. current row -> y, next selected row -> v, output: v;
// 6. current row -> null, next selected row -> v, output: v;
// 7. current row -> z, next selected row -> empty, output: null;
// ... next selected row is empty, all following return null.
(current: InternalRow) =>
if (nextSelectedRow == EmptyRow) {
// Use default values since the offset row whose input value is not null does not exist.
fillDefaultValue(current)
} else {
if (nullCheck(current)) {
// The current input is null, so the previously selected row is still the answer.
projection(nextSelectedRow)
} else {
// The current row has a non-null input: drop it from the count and search for one more
// non-null row ahead.
skippedNonNullCount -= 1
findNextRowWithNonNullInput()
if (skippedNonNullCount == offset) {
projection(nextSelectedRow)
} else {
// Use default values since the offset row whose input value is not null does not exist.
fillDefaultValue(current)
// Clear the selection so that every following row takes the default-value path above.
nextSelectedRow = EmptyRow
}
}
}
} else if (ignoreNulls && offset < 0) {
// For illustration, here is one example: the input data contains nine rows,
// and the input values of each row are: null, x, null, null, y, null, z, v, null.
// We use lag(input, 1) with IGNORE NULLS and the process is as follows:
// 1. current row -> null, next selected row -> empty, output: null;
// 2. current row -> x, next selected row -> empty, output: null;
// 3. current row -> null, next selected row -> x, output: x;
// 4. current row -> null, next selected row -> x, output: x;
// 5. current row -> y, next selected row -> x, output: x;
// 6. current row -> null, next selected row -> y, output: y;
// 7. current row -> z, next selected row -> y, output: y;
// 8. current row -> v, next selected row -> z, output: z;
// 9. current row -> null, next selected row -> v, output: v;
val absOffset = Math.abs(offset)
(current: InternalRow) =>
// Once `absOffset` non-null rows have been counted, move the selection forward to the
// next row of the input whose value is not null.
if (skippedNonNullCount == absOffset) {
nextSelectedRow = EmptyRow
skippedNonNullCount -= 1
while (nextSelectedRow == EmptyRow && inputIndex < input.length) {
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
if (!nullCheck(r)) {
nextSelectedRow = r
}
inputIndex += 1
}
}
if (nextSelectedRow == EmptyRow) {
// Use default values since the offset row whose input value is not null does not exist.
fillDefaultValue(current)
} else {
projection(nextSelectedRow)
}
// Count the current row only after it has been written, and only if its input is not null.
if (!nullCheck(current)) {
skippedNonNullCount += 1
}
} else {
(current: InternalRow) =>
if (inputIndex >= 0 && inputIndex < input.length) {
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
projection(r)
} else {
// Use default values since the offset row does not exist.
fillDefaultValue(current)
}
inputIndex += 1
}
override def write(index: Int, current: InternalRow): Unit = {
if (inputIndex >= 0 && inputIndex < input.length) {
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
projection(r)
} else {
// Use default values since the offset row does not exist.
fillDefaultValue(current)
}
inputIndex += 1
doWrite(current)
}
}

View file

@ -937,8 +937,24 @@ object functions {
* @group window_funcs
* @since 1.4.0
*/
def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
Lag(e.expr, Literal(offset), Literal(defaultValue))
def lag(e: Column, offset: Int, defaultValue: Any): Column = {
lag(e, offset, defaultValue, false)
}
/**
 * Window function: returns the value that is `offset` rows before the current row, or
 * `defaultValue` when fewer than `offset` rows precede the current row. The `ignoreNulls`
 * flag decides whether rows holding null values take part in the calculation or are left
 * out of it. For example, an `offset` of one will return the previous row at any given
 * point in the window partition.
 *
 * This is equivalent to the LAG function in SQL.
 *
 * @group window_funcs
 * @since 3.2.0
 */
def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = {
  val lagExpression = Lag(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
  withExpr(lagExpression)
}
/**
@ -989,8 +1005,24 @@ object functions {
* @group window_funcs
* @since 1.4.0
*/
def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
Lead(e.expr, Literal(offset), Literal(defaultValue))
def lead(e: Column, offset: Int, defaultValue: Any): Column = {
lead(e, offset, defaultValue, false)
}
/**
 * Window function: returns the value that is `offset` rows after the current row, or
 * `defaultValue` when fewer than `offset` rows follow the current row. The `ignoreNulls`
 * flag decides whether rows holding null values take part in the calculation or are left
 * out of it. The default value of `ignoreNulls` is false. For example, an `offset` of one
 * will return the next row at any given point in the window partition.
 *
 * This is equivalent to the LEAD function in SQL.
 *
 * @group window_funcs
 * @since 3.2.0
 */
def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = {
  val leadExpression = Lead(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
  withExpr(leadExpression)
}
/**

View file

@ -700,6 +700,61 @@ class DataFrameWindowFunctionsSuite extends QueryTest
Row("b", 3, null, null, null)))
}
test("lead/lag with ignoreNulls") {
  // Null and non-null values are interleaved in a single window ordered by `order`, so the
  // null-respecting and null-ignoring variants produce different results.
  val nullStr: String = null
  val input = Seq(
    ("a", 0, nullStr),
    ("a", 1, "x"),
    ("b", 2, nullStr),
    ("c", 3, nullStr),
    ("a", 4, "y"),
    ("b", 5, nullStr),
    ("a", 6, "z"),
    ("a", 7, "v"),
    ("a", 8, nullStr)).toDF("key", "order", "value")
  val byOrder = Window.orderBy($"order")

  // Columns: key, order, value, then eight lead variants followed by the same eight for lag.
  val actual = input.select(
    $"key",
    $"order",
    $"value",
    lead($"value", 1).over(byOrder),
    lead($"value", 2).over(byOrder),
    lead($"value", 0, null, true).over(byOrder),
    lead($"value", 1, null, true).over(byOrder),
    lead($"value", 2, null, true).over(byOrder),
    lead($"value", 3, null, true).over(byOrder),
    lead(concat($"value", $"key"), 1, null, true).over(byOrder),
    lag($"value", 1).over(byOrder),
    lag($"value", 2).over(byOrder),
    lag($"value", 0, null, true).over(byOrder),
    lag($"value", 1, null, true).over(byOrder),
    lag($"value", 2, null, true).over(byOrder),
    lag($"value", 3, null, true).over(byOrder),
    lag(concat($"value", $"key"), 1, null, true).over(byOrder))
    .orderBy($"order")

  val expected = Seq(
    Row("a", 0, null, "x", null, null, "x", "y", "z", "xa",
      null, null, null, null, null, null, null),
    Row("a", 1, "x", null, null, "x", "y", "z", "v", "ya",
      null, null, "x", null, null, null, null),
    Row("b", 2, null, null, "y", null, "y", "z", "v", "ya",
      "x", null, null, "x", null, null, "xa"),
    Row("c", 3, null, "y", null, null, "y", "z", "v", "ya",
      null, "x", null, "x", null, null, "xa"),
    Row("a", 4, "y", null, "z", "y", "z", "v", null, "za",
      null, null, "y", "x", null, null, "xa"),
    Row("b", 5, null, "z", "v", null, "z", "v", null, "za",
      "y", null, null, "y", "x", null, "ya"),
    Row("a", 6, "z", "v", null, "z", "v", null, null, "va",
      null, "y", "z", "y", "x", null, "ya"),
    Row("a", 7, "v", null, null, "v", null, null, null, null,
      "z", null, "v", "z", "y", "x", "za"),
    Row("a", 8, null, null, null, null, null, null, null, null,
      "v", "z", null, "v", "z", "y", "va"))

  checkAnswer(actual, expected)
}
test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
val src = Seq((0, 3, 5)).toDF("a", "b", "c")
.withColumn("Data", struct("a", "b"))