[SPARK-34227][SQL] WindowFunctionFrame should clear its states during preparation

### What changes were proposed in this pull request?

This PR fixes all `OffsetWindowFunctionFrameBase#prepare` implementations to reset their states, and also adds more comments in the `WindowFunctionFrame` classdoc to explain why we need to reset states during preparation: `WindowFunctionFrame` instances are reused to process multiple partitions.

### Why are the changes needed?

To fix a correctness bug caused by the new feature "window function with ignore nulls" in the master branch.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a correctness bug: queries using window functions with ignore nulls could previously return wrong results.

### How was this patch tested?

A new test was added.

Closes #31325 from cloud-fan/bug.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Wenchen Fan 2021-01-26 08:50:14 +00:00
parent 0a1a029622
commit 8dee8a9b7c
2 changed files with 48 additions and 19 deletions

View file

@ -30,6 +30,9 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
* A window function calculates the results of a number of window functions for a window frame.
* Before use a frame must be prepared by passing it all the rows in the current partition. After
* preparation the update method can be called to fill the output rows.
*
* Note: `WindowFunctionFrame` instances are reused during window execution. The `prepare` method
* will be called before processing the next partition, and must reset the states.
*/
abstract class WindowFunctionFrame {
/**
@ -137,6 +140,15 @@ abstract class OffsetWindowFunctionFrameBase(
// is not null.
protected var skippedNonNullCount = 0
// Reset the states by the data of the new partition.
protected def resetStates(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIterator = input.generateIterator()
inputIndex = 0
skippedNonNullCount = 0
nextSelectedRow = EmptyRow
}
/** Create the projection to determine whether input is null. */
protected val project = UnsafeProjection.create(Seq(IsNull(expressions.head.input)), inputSchema)
@ -179,13 +191,11 @@ class FrameLessOffsetWindowFunctionFrame(
target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than zero
inputIndex = 0
resetStates(rows)
if (ignoreNulls) {
findNextRowWithNonNullInput()
} else {
// drain the first few rows if offset is larger than zero
while (inputIndex < offset) {
if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
@ -299,13 +309,10 @@ class UnboundedOffsetWindowFunctionFrame(
assert(offset > 0)
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
if (offset > input.length) {
if (offset > rows.length) {
fillDefaultValue(EmptyRow)
} else {
inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than one
inputIndex = 0
resetStates(rows)
if (ignoreNulls) {
findNextRowWithNonNullInput()
if (nextSelectedRow == EmptyRow) {
@ -316,6 +323,7 @@ class UnboundedOffsetWindowFunctionFrame(
}
} else {
var selectedRow: UnsafeRow = null
// drain the first few rows if offset is larger than one
while (inputIndex < offset) {
selectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputIndex += 1
@ -353,27 +361,22 @@ class UnboundedPrecedingOffsetWindowFunctionFrame(
target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
assert(offset > 0)
var selectedRow: UnsafeRow = null
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than one
inputIndex = 0
resetStates(rows)
if (ignoreNulls) {
findNextRowWithNonNullInput()
selectedRow = nextSelectedRow.asInstanceOf[UnsafeRow]
} else {
// drain the first few rows if offset is larger than one
while (inputIndex < offset) {
selectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputIndex += 1
}
}
}
override def write(index: Int, current: InternalRow): Unit = {
if (index >= inputIndex - 1 && selectedRow != null) {
projection(selectedRow)
if (index >= inputIndex - 1 && nextSelectedRow != null) {
projection(nextSelectedRow)
} else {
fillDefaultValue(EmptyRow)
}

View file

@ -1044,4 +1044,30 @@ class DataFrameWindowFunctionsSuite extends QueryTest
Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), Seq(Row(-0.0d, Double.NaN)), 2),
Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, 0.0/0.0)), 2)))
}
test("SPARK-34227: WindowFunctionFrame should clear its states during preparation") {
// This creates a single partition dataframe with 3 records:
// "a", 0, null
// "a", 1, "x"
// "b", 0, null
val df = spark.range(0, 3, 1, 1).select(
when($"id" < 2, lit("a")).otherwise(lit("b")).as("key"),
($"id" % 2).cast("int").as("order"),
when($"id" % 2 === 0, lit(null)).otherwise(lit("x")).as("value"))
val window1 = Window.partitionBy($"key").orderBy($"order")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
val window2 = Window.partitionBy($"key").orderBy($"order")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
checkAnswer(
df.select(
$"key",
$"order",
nth_value($"value", 1, ignoreNulls = true).over(window1),
nth_value($"value", 1, ignoreNulls = true).over(window2)),
Seq(
Row("a", 0, "x", null),
Row("a", 1, "x", "x"),
Row("b", 0, null, null)))
}
}