[SPARK-34227][SQL] WindowFunctionFrame should clear its states during preparation
### What changes were proposed in this pull request?

This PR fixes all `OffsetWindowFunctionFrameBase#prepare` implementations to reset their states, and adds more comments to the `WindowFunctionFrame` classdoc to explain why we need to reset states during preparation: `WindowFunctionFrame` instances are reused to process multiple partitions.

### Why are the changes needed?

To fix a correctness bug caused by the new feature "window function with ignore nulls" in the master branch.

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

New test.

Closes #31325 from cloud-fan/bug.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
0a1a029622
commit
8dee8a9b7c
|
@ -30,6 +30,9 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
|
|||
* A window function calculates the results of a number of window functions for a window frame.
|
||||
 * Before use, a frame must be prepared by passing it all the rows in the current partition. After
|
||||
* preparation the update method can be called to fill the output rows.
|
||||
*
|
||||
* Note: `WindowFunctionFrame` instances are reused during window execution. The `prepare` method
|
||||
* will be called before processing the next partition, and must reset the states.
|
||||
*/
|
||||
abstract class WindowFunctionFrame {
|
||||
/**
|
||||
|
@ -137,6 +140,15 @@ abstract class OffsetWindowFunctionFrameBase(
|
|||
// is not null.
|
||||
protected var skippedNonNullCount = 0
|
||||
|
||||
// Reset the states by the data of the new partition.
|
||||
protected def resetStates(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
|
||||
input = rows
|
||||
inputIterator = input.generateIterator()
|
||||
inputIndex = 0
|
||||
skippedNonNullCount = 0
|
||||
nextSelectedRow = EmptyRow
|
||||
}
|
||||
|
||||
/** Create the projection to determine whether input is null. */
|
||||
protected val project = UnsafeProjection.create(Seq(IsNull(expressions.head.input)), inputSchema)
|
||||
|
||||
|
@ -179,13 +191,11 @@ class FrameLessOffsetWindowFunctionFrame(
|
|||
target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
|
||||
|
||||
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
|
||||
input = rows
|
||||
inputIterator = input.generateIterator()
|
||||
// drain the first few rows if offset is larger than zero
|
||||
inputIndex = 0
|
||||
resetStates(rows)
|
||||
if (ignoreNulls) {
|
||||
findNextRowWithNonNullInput()
|
||||
} else {
|
||||
// drain the first few rows if offset is larger than zero
|
||||
while (inputIndex < offset) {
|
||||
if (inputIterator.hasNext) inputIterator.next()
|
||||
inputIndex += 1
|
||||
|
@ -299,13 +309,10 @@ class UnboundedOffsetWindowFunctionFrame(
|
|||
assert(offset > 0)
|
||||
|
||||
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
|
||||
input = rows
|
||||
if (offset > input.length) {
|
||||
if (offset > rows.length) {
|
||||
fillDefaultValue(EmptyRow)
|
||||
} else {
|
||||
inputIterator = input.generateIterator()
|
||||
// drain the first few rows if offset is larger than one
|
||||
inputIndex = 0
|
||||
resetStates(rows)
|
||||
if (ignoreNulls) {
|
||||
findNextRowWithNonNullInput()
|
||||
if (nextSelectedRow == EmptyRow) {
|
||||
|
@ -316,6 +323,7 @@ class UnboundedOffsetWindowFunctionFrame(
|
|||
}
|
||||
} else {
|
||||
var selectedRow: UnsafeRow = null
|
||||
// drain the first few rows if offset is larger than one
|
||||
while (inputIndex < offset) {
|
||||
selectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
|
||||
inputIndex += 1
|
||||
|
@ -353,27 +361,22 @@ class UnboundedPrecedingOffsetWindowFunctionFrame(
|
|||
target, ordinal, expressions, inputSchema, newMutableProjection, offset) {
|
||||
assert(offset > 0)
|
||||
|
||||
var selectedRow: UnsafeRow = null
|
||||
|
||||
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
|
||||
input = rows
|
||||
inputIterator = input.generateIterator()
|
||||
// drain the first few rows if offset is larger than one
|
||||
inputIndex = 0
|
||||
resetStates(rows)
|
||||
if (ignoreNulls) {
|
||||
findNextRowWithNonNullInput()
|
||||
selectedRow = nextSelectedRow.asInstanceOf[UnsafeRow]
|
||||
} else {
|
||||
// drain the first few rows if offset is larger than one
|
||||
while (inputIndex < offset) {
|
||||
selectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
|
||||
nextSelectedRow = WindowFunctionFrame.getNextOrNull(inputIterator)
|
||||
inputIndex += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def write(index: Int, current: InternalRow): Unit = {
|
||||
if (index >= inputIndex - 1 && selectedRow != null) {
|
||||
projection(selectedRow)
|
||||
if (index >= inputIndex - 1 && nextSelectedRow != null) {
|
||||
projection(nextSelectedRow)
|
||||
} else {
|
||||
fillDefaultValue(EmptyRow)
|
||||
}
|
||||
|
|
|
@ -1044,4 +1044,30 @@ class DataFrameWindowFunctionsSuite extends QueryTest
|
|||
Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), Seq(Row(-0.0d, Double.NaN)), 2),
|
||||
Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, 0.0/0.0)), 2)))
|
||||
}
|
||||
|
||||
test("SPARK-34227: WindowFunctionFrame should clear its states during preparation") {
|
||||
// This creates a single partition dataframe with 3 records:
|
||||
// "a", 0, null
|
||||
// "a", 1, "x"
|
||||
// "b", 0, null
|
||||
val df = spark.range(0, 3, 1, 1).select(
|
||||
when($"id" < 2, lit("a")).otherwise(lit("b")).as("key"),
|
||||
($"id" % 2).cast("int").as("order"),
|
||||
when($"id" % 2 === 0, lit(null)).otherwise(lit("x")).as("value"))
|
||||
|
||||
val window1 = Window.partitionBy($"key").orderBy($"order")
|
||||
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
|
||||
val window2 = Window.partitionBy($"key").orderBy($"order")
|
||||
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
|
||||
checkAnswer(
|
||||
df.select(
|
||||
$"key",
|
||||
$"order",
|
||||
nth_value($"value", 1, ignoreNulls = true).over(window1),
|
||||
nth_value($"value", 1, ignoreNulls = true).over(window2)),
|
||||
Seq(
|
||||
Row("a", 0, "x", null),
|
||||
Row("a", 1, "x", "x"),
|
||||
Row("b", 0, null, null)))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue