[SPARK-33443][SQL] LEAD/LAG should support [ IGNORE NULLS | RESPECT NULLS ]
### What changes were proposed in this pull request? The mainstream database support `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG`/`NTH_VALUE`/`FIRST_VALUE`/`LAST_VALUE`. But the current implement of `LEAD`/`LAG` don't support this syntax. **Oracle** https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/LEAD.html#GUID-0A0481F1-E98F-4535-A739-FCCA8D1B5B77 **Presto** https://prestodb.io/docs/current/functions/window.html **Redshift** https://docs.aws.amazon.com/redshift/latest/dg/r_WF_LEAD.html **DB2** https://www.ibm.com/support/knowledgecenter/SSGU8G_14.1.0/com.ibm.sqls.doc/ids_sqs_1513.htm **Teradata** https://docs.teradata.com/r/756LNiPSFdY~4JcCCcR5Cw/GjCT6l7trjkIEjt~7Dhx4w **Snowflake** https://docs.snowflake.com/en/sql-reference/functions/lead.html https://docs.snowflake.com/en/sql-reference/functions/lag.html ### Why are the changes needed? Support `[ IGNORE NULLS | RESPECT NULLS ]` for `LEAD`/`LAG` is very useful. ### Does this PR introduce _any_ user-facing change? 'Yes'. ### How was this patch tested? Jenkins test. Closes #30387 from beliefer/SPARK-33443. Lead-authored-by: gengjiaan <gengjiaan@360.cn> Co-authored-by: beliefer <beliefer@163.com> Co-authored-by: Jiaan Geng <beliefer@163.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
32d4a2b062
commit
3e9821edfd
|
@ -387,8 +387,6 @@ abstract class FrameLessOffsetWindowFunction
|
|||
|
||||
override def nullable: Boolean = default == null || default.nullable || input.nullable
|
||||
|
||||
override val ignoreNulls = false
|
||||
|
||||
override lazy val frame: WindowFrame = fakeFrame
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
|
@ -443,9 +441,13 @@ abstract class FrameLessOffsetWindowFunction
|
|||
since = "2.0.0",
|
||||
group = "window_funcs")
|
||||
// scalastyle:on line.size.limit line.contains.tab
|
||||
case class Lead(input: Expression, offset: Expression, default: Expression)
|
||||
case class Lead(
|
||||
input: Expression, offset: Expression, default: Expression, ignoreNulls: Boolean)
|
||||
extends FrameLessOffsetWindowFunction {
|
||||
|
||||
def this(input: Expression, offset: Expression, default: Expression) =
|
||||
this(input, offset, default, false)
|
||||
|
||||
def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
|
||||
|
||||
def this(input: Expression) = this(input, Literal(1))
|
||||
|
@ -485,10 +487,14 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
|
|||
since = "2.0.0",
|
||||
group = "window_funcs")
|
||||
// scalastyle:on line.size.limit line.contains.tab
|
||||
case class Lag(input: Expression, inputOffset: Expression, default: Expression)
|
||||
case class Lag(
|
||||
input: Expression, inputOffset: Expression, default: Expression, ignoreNulls: Boolean)
|
||||
extends FrameLessOffsetWindowFunction {
|
||||
|
||||
def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))
|
||||
def this(input: Expression, inputOffset: Expression, default: Expression) =
|
||||
this(input, inputOffset, default, false)
|
||||
|
||||
def this(input: Expression, inputOffset: Expression) = this(input, inputOffset, Literal(null))
|
||||
|
||||
def this(input: Expression) = this(input, Literal(1))
|
||||
|
||||
|
|
|
@ -119,13 +119,21 @@ trait WindowExecBase extends UnaryExecNode {
|
|||
* [[WindowExpression]]s and factory function for the [[WindowFrameFunction]].
|
||||
*/
|
||||
protected lazy val windowFrameExpressionFactoryPairs = {
|
||||
type FrameKey = (String, FrameType, Expression, Expression)
|
||||
type FrameKey = (String, FrameType, Expression, Expression, Seq[Expression])
|
||||
type ExpressionBuffer = mutable.Buffer[Expression]
|
||||
val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
|
||||
|
||||
// Add a function and its function to the map for a given frame.
|
||||
def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
|
||||
val key = (tpe, fr.frameType, fr.lower, fr.upper)
|
||||
val key = fn match {
|
||||
// This branch is used for Lead/Lag to support ignoring null.
|
||||
// All window frames move in rows. If there are multiple Leads or Lags acting on a row
|
||||
// and operating on different input expressions, they should not be moved uniformly
|
||||
// by row. Therefore, we put these functions in different window frames.
|
||||
case f: FrameLessOffsetWindowFunction if f.ignoreNulls =>
|
||||
(tpe, fr.frameType, fr.lower, fr.upper, f.children.map(_.canonicalized))
|
||||
case _ => (tpe, fr.frameType, fr.lower, fr.upper, Nil)
|
||||
}
|
||||
val (es, fns) = framedFunctions.getOrElseUpdate(
|
||||
key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
|
||||
es += e
|
||||
|
@ -183,7 +191,7 @@ trait WindowExecBase extends UnaryExecNode {
|
|||
// Create the factory to produce WindowFunctionFrame.
|
||||
val factory = key match {
|
||||
// Frameless offset Frame
|
||||
case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _) =>
|
||||
case ("FRAME_LESS_OFFSET", _, IntegerLiteral(offset), _, expr) =>
|
||||
target: InternalRow =>
|
||||
new FrameLessOffsetWindowFunctionFrame(
|
||||
target,
|
||||
|
@ -193,8 +201,9 @@ trait WindowExecBase extends UnaryExecNode {
|
|||
child.output,
|
||||
(expressions, schema) =>
|
||||
MutableProjection.create(expressions, schema),
|
||||
offset)
|
||||
case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _) =>
|
||||
offset,
|
||||
expr.nonEmpty)
|
||||
case ("UNBOUNDED_OFFSET", _, IntegerLiteral(offset), _, _) =>
|
||||
target: InternalRow => {
|
||||
new UnboundedOffsetWindowFunctionFrame(
|
||||
target,
|
||||
|
@ -206,7 +215,7 @@ trait WindowExecBase extends UnaryExecNode {
|
|||
MutableProjection.create(expressions, schema),
|
||||
offset)
|
||||
}
|
||||
case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _) =>
|
||||
case ("UNBOUNDED_PRECEDING_OFFSET", _, IntegerLiteral(offset), _, _) =>
|
||||
target: InternalRow => {
|
||||
new UnboundedPrecedingOffsetWindowFunctionFrame(
|
||||
target,
|
||||
|
@ -220,13 +229,13 @@ trait WindowExecBase extends UnaryExecNode {
|
|||
}
|
||||
|
||||
// Entire Partition Frame.
|
||||
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
|
||||
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing, _) =>
|
||||
target: InternalRow => {
|
||||
new UnboundedWindowFunctionFrame(target, processor)
|
||||
}
|
||||
|
||||
// Growing Frame.
|
||||
case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
|
||||
case ("AGGREGATE", frameType, UnboundedPreceding, upper, _) =>
|
||||
target: InternalRow => {
|
||||
new UnboundedPrecedingWindowFunctionFrame(
|
||||
target,
|
||||
|
@ -235,7 +244,7 @@ trait WindowExecBase extends UnaryExecNode {
|
|||
}
|
||||
|
||||
// Shrinking Frame.
|
||||
case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
|
||||
case ("AGGREGATE", frameType, lower, UnboundedFollowing, _) =>
|
||||
target: InternalRow => {
|
||||
new UnboundedFollowingWindowFunctionFrame(
|
||||
target,
|
||||
|
@ -244,7 +253,7 @@ trait WindowExecBase extends UnaryExecNode {
|
|||
}
|
||||
|
||||
// Moving Frame.
|
||||
case ("AGGREGATE", frameType, lower, upper) =>
|
||||
case ("AGGREGATE", frameType, lower, upper, _) =>
|
||||
target: InternalRow => {
|
||||
new SlidingWindowFunctionFrame(
|
||||
target,
|
||||
|
|
|
@ -97,13 +97,15 @@ abstract class OffsetWindowFunctionFrameBase(
|
|||
/** Index of the input row currently used for output. */
|
||||
protected var inputIndex = 0
|
||||
|
||||
/** Attributes of the input row currently used for output. */
|
||||
protected val inputAttrs = inputSchema.map(_.withNullability(true))
|
||||
|
||||
/**
|
||||
* Create the projection used when the offset row exists.
|
||||
* Please note that this project always respect null input values (like PostgreSQL).
|
||||
*/
|
||||
protected val projection = {
|
||||
// Collect the expressions and bind them.
|
||||
val inputAttrs = inputSchema.map(_.withNullability(true))
|
||||
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
|
||||
expressions.toSeq.map(_.input), inputAttrs)
|
||||
|
||||
|
@ -114,7 +116,6 @@ abstract class OffsetWindowFunctionFrameBase(
|
|||
/** Create the projection used when the offset row DOES NOT exists. */
|
||||
protected val fillDefaultValue = {
|
||||
// Collect the expressions and bind them.
|
||||
val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
|
||||
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
|
||||
if (e.default == null || e.default.foldable && e.default.eval() == null) {
|
||||
// The default value is null.
|
||||
|
@ -147,31 +148,132 @@ class FrameLessOffsetWindowFunctionFrame(
|
|||
expressions: Array[OffsetWindowFunction],
|
||||
inputSchema: Seq[Attribute],
|
||||
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
|
||||
offset: Int)
|
||||
offset: Int,
|
||||
ignoreNulls: Boolean = false)
|
||||
extends OffsetWindowFunctionFrameBase(
|
||||
target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
|
||||
|
||||
/** Holder the UnsafeRow where the input operator by function is not null. */
|
||||
private var nextSelectedRow = EmptyRow
|
||||
|
||||
// The number of rows skipped to get the next UnsafeRow where the input operator by function
|
||||
// is not null.
|
||||
private var skippedNonNullCount = 0
|
||||
|
||||
/** Create the projection to determine whether input is null. */
|
||||
private val project = UnsafeProjection.create(Seq(IsNull(expressions.head.input)), inputSchema)
|
||||
|
||||
/** Check if the output value of the first index is null. */
|
||||
private def nullCheck(row: InternalRow): Boolean = project(row).getBoolean(0)
|
||||
|
||||
/** find the offset row whose input is not null */
|
||||
private def findNextRowWithNonNullInput(): Unit = {
|
||||
while (skippedNonNullCount < offset && inputIndex < input.length) {
|
||||
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
|
||||
if (!nullCheck(r)) {
|
||||
nextSelectedRow = r
|
||||
skippedNonNullCount += 1
|
||||
}
|
||||
inputIndex += 1
|
||||
}
|
||||
}
|
||||
|
||||
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
|
||||
input = rows
|
||||
inputIterator = input.generateIterator()
|
||||
// drain the first few rows if offset is larger than zero
|
||||
inputIndex = 0
|
||||
while (inputIndex < offset) {
|
||||
if (inputIterator.hasNext) inputIterator.next()
|
||||
inputIndex += 1
|
||||
if (ignoreNulls) {
|
||||
findNextRowWithNonNullInput()
|
||||
} else {
|
||||
while (inputIndex < offset) {
|
||||
if (inputIterator.hasNext) inputIterator.next()
|
||||
inputIndex += 1
|
||||
}
|
||||
inputIndex = offset
|
||||
}
|
||||
inputIndex = offset
|
||||
}
|
||||
|
||||
private val doWrite = if (ignoreNulls && offset > 0) {
|
||||
// For illustration, here is one example: the input data contains nine rows,
|
||||
// and the input values of each row are: null, x, null, null, y, null, z, v, null.
|
||||
// We use lead(input, 2) with IGNORE NULLS and the process is as follows:
|
||||
// 1. current row -> null, next selected row -> y, output: y;
|
||||
// 2. current row -> x, next selected row -> z, output: z;
|
||||
// 3. current row -> null, next selected row -> z, output: z;
|
||||
// 4. current row -> null, next selected row -> z, output: z;
|
||||
// 5. current row -> y, next selected row -> v, output: v;
|
||||
// 6. current row -> null, next selected row -> v, output: v;
|
||||
// 7. current row -> z, next selected row -> empty, output: null;
|
||||
// ... next selected row is empty, all following return null.
|
||||
(current: InternalRow) =>
|
||||
if (nextSelectedRow == EmptyRow) {
|
||||
// Use default values since the offset row whose input value is not null does not exist.
|
||||
fillDefaultValue(current)
|
||||
} else {
|
||||
if (nullCheck(current)) {
|
||||
projection(nextSelectedRow)
|
||||
} else {
|
||||
skippedNonNullCount -= 1
|
||||
findNextRowWithNonNullInput()
|
||||
if (skippedNonNullCount == offset) {
|
||||
projection(nextSelectedRow)
|
||||
} else {
|
||||
// Use default values since the offset row whose input value is not null does not exist.
|
||||
fillDefaultValue(current)
|
||||
nextSelectedRow = EmptyRow
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (ignoreNulls && offset < 0) {
|
||||
// For illustration, here is one example: the input data contains nine rows,
|
||||
// and the input values of each row are: null, x, null, null, y, null, z, v, null.
|
||||
// We use lag(input, 1) with IGNORE NULLS and the process is as follows:
|
||||
// 1. current row -> null, next selected row -> empty, output: null;
|
||||
// 2. current row -> x, next selected row -> empty, output: null;
|
||||
// 3. current row -> null, next selected row -> x, output: x;
|
||||
// 4. current row -> null, next selected row -> x, output: x;
|
||||
// 5. current row -> y, next selected row -> x, output: x;
|
||||
// 6. current row -> null, next selected row -> y, output: y;
|
||||
// 7. current row -> z, next selected row -> y, output: y;
|
||||
// 8. current row -> v, next selected row -> z, output: z;
|
||||
// 9. current row -> null, next selected row -> v, output: v;
|
||||
val absOffset = Math.abs(offset)
|
||||
(current: InternalRow) =>
|
||||
if (skippedNonNullCount == absOffset) {
|
||||
nextSelectedRow = EmptyRow
|
||||
skippedNonNullCount -= 1
|
||||
while (nextSelectedRow == EmptyRow && inputIndex < input.length) {
|
||||
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
|
||||
if (!nullCheck(r)) {
|
||||
nextSelectedRow = r
|
||||
}
|
||||
inputIndex += 1
|
||||
}
|
||||
}
|
||||
if (nextSelectedRow == EmptyRow) {
|
||||
// Use default values since the offset row whose input value is not null does not exist.
|
||||
fillDefaultValue(current)
|
||||
} else {
|
||||
projection(nextSelectedRow)
|
||||
}
|
||||
if (!nullCheck(current)) {
|
||||
skippedNonNullCount += 1
|
||||
}
|
||||
} else {
|
||||
(current: InternalRow) =>
|
||||
if (inputIndex >= 0 && inputIndex < input.length) {
|
||||
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
|
||||
projection(r)
|
||||
} else {
|
||||
// Use default values since the offset row does not exist.
|
||||
fillDefaultValue(current)
|
||||
}
|
||||
inputIndex += 1
|
||||
}
|
||||
|
||||
override def write(index: Int, current: InternalRow): Unit = {
|
||||
if (inputIndex >= 0 && inputIndex < input.length) {
|
||||
val r = WindowFunctionFrame.getNextOrNull(inputIterator)
|
||||
projection(r)
|
||||
} else {
|
||||
// Use default values since the offset row does not exist.
|
||||
fillDefaultValue(current)
|
||||
}
|
||||
inputIndex += 1
|
||||
doWrite(current)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -937,8 +937,24 @@ object functions {
|
|||
* @group window_funcs
|
||||
* @since 1.4.0
|
||||
*/
|
||||
def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
|
||||
Lag(e.expr, Literal(offset), Literal(defaultValue))
|
||||
def lag(e: Column, offset: Int, defaultValue: Any): Column = {
|
||||
lag(e, offset, defaultValue, false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Window function: returns the value that is `offset` rows before the current row, and
|
||||
* `defaultValue` if there is less than `offset` rows before the current row. `ignoreNulls`
|
||||
* determines whether null values of row are included in or eliminated from the calculation.
|
||||
* For example, an `offset` of one will return the previous row at any given point in the
|
||||
* window partition.
|
||||
*
|
||||
* This is equivalent to the LAG function in SQL.
|
||||
*
|
||||
* @group window_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
|
||||
Lag(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -989,8 +1005,24 @@ object functions {
|
|||
* @group window_funcs
|
||||
* @since 1.4.0
|
||||
*/
|
||||
def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
|
||||
Lead(e.expr, Literal(offset), Literal(defaultValue))
|
||||
def lead(e: Column, offset: Int, defaultValue: Any): Column = {
|
||||
lead(e, offset, defaultValue, false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Window function: returns the value that is `offset` rows after the current row, and
|
||||
* `defaultValue` if there is less than `offset` rows after the current row. `ignoreNulls`
|
||||
* determines whether null values of row are included in or eliminated from the calculation.
|
||||
* The default value of `ignoreNulls` is false. For example, an `offset` of one will return
|
||||
* the next row at any given point in the window partition.
|
||||
*
|
||||
* This is equivalent to the LEAD function in SQL.
|
||||
*
|
||||
* @group window_funcs
|
||||
* @since 3.2.0
|
||||
*/
|
||||
def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = withExpr {
|
||||
Lead(e.expr, Literal(offset), Literal(defaultValue), ignoreNulls)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -700,6 +700,61 @@ class DataFrameWindowFunctionsSuite extends QueryTest
|
|||
Row("b", 3, null, null, null)))
|
||||
}
|
||||
|
||||
test("lead/lag with ignoreNulls") {
|
||||
val nullStr: String = null
|
||||
val df = Seq(
|
||||
("a", 0, nullStr),
|
||||
("a", 1, "x"),
|
||||
("b", 2, nullStr),
|
||||
("c", 3, nullStr),
|
||||
("a", 4, "y"),
|
||||
("b", 5, nullStr),
|
||||
("a", 6, "z"),
|
||||
("a", 7, "v"),
|
||||
("a", 8, nullStr)).
|
||||
toDF("key", "order", "value")
|
||||
val window = Window.orderBy($"order")
|
||||
checkAnswer(
|
||||
df.select(
|
||||
$"key",
|
||||
$"order",
|
||||
$"value",
|
||||
lead($"value", 1).over(window),
|
||||
lead($"value", 2).over(window),
|
||||
lead($"value", 0, null, true).over(window),
|
||||
lead($"value", 1, null, true).over(window),
|
||||
lead($"value", 2, null, true).over(window),
|
||||
lead($"value", 3, null, true).over(window),
|
||||
lead(concat($"value", $"key"), 1, null, true).over(window),
|
||||
lag($"value", 1).over(window),
|
||||
lag($"value", 2).over(window),
|
||||
lag($"value", 0, null, true).over(window),
|
||||
lag($"value", 1, null, true).over(window),
|
||||
lag($"value", 2, null, true).over(window),
|
||||
lag($"value", 3, null, true).over(window),
|
||||
lag(concat($"value", $"key"), 1, null, true).over(window))
|
||||
.orderBy($"order"),
|
||||
Seq(
|
||||
Row("a", 0, null, "x", null, null, "x", "y", "z", "xa",
|
||||
null, null, null, null, null, null, null),
|
||||
Row("a", 1, "x", null, null, "x", "y", "z", "v", "ya",
|
||||
null, null, "x", null, null, null, null),
|
||||
Row("b", 2, null, null, "y", null, "y", "z", "v", "ya",
|
||||
"x", null, null, "x", null, null, "xa"),
|
||||
Row("c", 3, null, "y", null, null, "y", "z", "v", "ya",
|
||||
null, "x", null, "x", null, null, "xa"),
|
||||
Row("a", 4, "y", null, "z", "y", "z", "v", null, "za",
|
||||
null, null, "y", "x", null, null, "xa"),
|
||||
Row("b", 5, null, "z", "v", null, "z", "v", null, "za",
|
||||
"y", null, null, "y", "x", null, "ya"),
|
||||
Row("a", 6, "z", "v", null, "z", "v", null, null, "va",
|
||||
null, "y", "z", "y", "x", null, "ya"),
|
||||
Row("a", 7, "v", null, null, "v", null, null, null, null,
|
||||
"z", null, "v", "z", "y", "x", "za"),
|
||||
Row("a", 8, null, null, null, null, null, null, null, null,
|
||||
"v", "z", null, "v", "z", "y", "va")))
|
||||
}
|
||||
|
||||
test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
|
||||
val src = Seq((0, 3, 5)).toDF("a", "b", "c")
|
||||
.withColumn("Data", struct("a", "b"))
|
||||
|
|
Loading…
Reference in a new issue