[SPARK-16633][SPARK-16642][SPARK-16721][SQL] Fixes three issues related to lead and lag functions

## What changes were proposed in this pull request?
This PR contains three changes.

First, this PR changes the behavior of lead/lag back to Spark 1.6's behavior, which is described as below:
1. lead/lag respect null input values, which means that if the offset row exists and the input value is null, the result will be null instead of the default value.
2. If the offset row does not exist, the default value will be used.
3. OffsetWindowFunction's nullable setting also considers the nullability of its input (because of the first change).

Second, this PR fixes the evaluation of lead/lag when the input expression is a literal. This fix is a result of the first change. In current master, if a literal is used as the input expression of a lead or lag function, the result will be this literal even if the offset row does not exist.

Third, this PR makes ResolveWindowFrame not fire if a window function is not resolved.

## How was this patch tested?
New tests in SQLWindowFunctionSuite

Author: Yin Huai <yhuai@databricks.com>

Closes #14284 from yhuai/lead-lag.
This commit is contained in:
Yin Huai 2016-07-25 20:58:07 -07:00
parent f99e34e8e5
commit 815f3eece5
4 changed files with 108 additions and 40 deletions

View file

@ -1787,7 +1787,8 @@ class Analyzer(
s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
if wf.frame != UnspecifiedFrame =>
WindowExpression(wf, s.copy(frameSpecification = wf.frame))
case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) =>
case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
if e.resolved =>
val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true)
we.copy(windowSpec = s.copy(frameSpecification = frame))
}

View file

@ -321,8 +321,7 @@ abstract class OffsetWindowFunction
val input: Expression
/**
* Default result value for the function when the input expression returns NULL. The default will
* evaluated against the current row instead of the offset row.
* Default result value for the function when the 'offset'th row does not exist.
*/
val default: Expression
@ -348,7 +347,7 @@ abstract class OffsetWindowFunction
*/
override def foldable: Boolean = false
override def nullable: Boolean = default == null || default.nullable
override def nullable: Boolean = default == null || default.nullable || input.nullable
override lazy val frame = {
// This will be triggered by the Analyzer.
@ -373,20 +372,22 @@ abstract class OffsetWindowFunction
}
/**
* The Lead function returns the value of 'x' at 'offset' rows after the current row in the window.
* Offsets start at 0, which is the current row. The offset must be constant integer value. The
* default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger
* than the window, the default expression is evaluated.
*
* This documentation has been based upon similar documentation for the Hive and Presto projects.
* The Lead function returns the value of 'x' at the 'offset'th row after the current row in
* the window. Offsets start at 0, which is the current row. The offset must be constant
* integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row,
* null is returned. If there is no such offset row, the default expression is evaluated.
*
* @param input expression to evaluate 'offset' rows after the current row.
* @param offset rows to jump ahead in the partition.
* @param default to use when the input value is null or when the offset is larger than the window.
* @param default to use when the offset is larger than the window. The default value is null.
*/
@ExpressionDescription(usage =
"""_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows
after the current row in the window""")
"""_FUNC_(input, offset, default) - LEAD returns the value of 'x' at the 'offset'th row
after the current row in the window.
The default value of 'offset' is 1 and the default value of 'default' is null.
If the value of 'x' at the 'offset'th row is null, null is returned.
If there is no such offset row (e.g. when the offset is 1, the last row of the window
does not have any subsequent row), 'default' is returned.""")
case class Lead(input: Expression, offset: Expression, default: Expression)
extends OffsetWindowFunction {
@ -400,20 +401,22 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
}
/**
* The Lag function returns the value of 'x' at 'offset' rows before the current row in the window.
* Offsets start at 0, which is the current row. The offset must be constant integer value. The
* default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller
* than the window, the default expression is evaluated.
*
* This documentation has been based upon similar documentation for the Hive and Presto projects.
* The Lag function returns the value of 'x' at the 'offset'th row before the current row in
* the window. Offsets start at 0, which is the current row. The offset must be constant
* integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row,
* null is returned. If there is no such offset row, the default expression is evaluated.
*
* @param input expression to evaluate 'offset' rows before the current row.
* @param offset rows to jump back in the partition.
* @param default to use when the input value is null or when the offset is smaller than the window.
* @param default to use when the offset row does not exist.
*/
@ExpressionDescription(usage =
"""_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows
before the current row in the window""")
"""_FUNC_(input, offset, default) - LAG returns the value of 'x' at the 'offset'th row
before the current row in the window.
The default value of 'offset' is 1 and the default value of 'default' is null.
If the value of 'x' at the 'offset'th row is null, null is returned.
If there is no such offset row (e.g. when the offset is 1, the first row of the window
does not have any previous row), 'default' is returned.""")
case class Lag(input: Expression, offset: Expression, default: Expression)
extends OffsetWindowFunction {

View file

@ -582,25 +582,43 @@ private[execution] final class OffsetWindowFunctionFrame(
/** Row used to combine the offset and the current row. */
private[this] val join = new JoinedRow
/** Create the projection. */
/**
* Create the projection used when the offset row exists.
* Please note that this project always respect null input values (like PostgreSQL).
*/
private[this] val projection = {
// Collect the expressions and bind them.
val inputAttrs = inputSchema.map(_.withNullability(true))
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
case e: OffsetWindowFunction =>
val input = BindReferences.bindReference(e.input, inputAttrs)
input
case e =>
BindReferences.bindReference(e, inputAttrs)
}
// Create the projection.
newMutableProjection(boundExpressions, Nil).target(target)
}
/** Create the projection used when the offset row DOES NOT exists. */
private[this] val fillDefaultValue = {
// Collect the expressions and bind them.
val inputAttrs = inputSchema.map(_.withNullability(true))
val numInputAttributes = inputAttrs.size
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
case e: OffsetWindowFunction =>
val input = BindReferences.bindReference(e.input, inputAttrs)
if (e.default == null || e.default.foldable && e.default.eval() == null) {
// Without default value.
input
// The default value is null.
Literal.create(null, e.dataType)
} else {
// With default value.
// The default value is an expression.
val default = BindReferences.bindReference(e.default, inputAttrs).transform {
// Shift the input reference to its default version.
case BoundReference(o, dataType, nullable) =>
BoundReference(o + numInputAttributes, dataType, nullable)
}
org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil)
default
}
case e =>
BindReferences.bindReference(e, inputAttrs)
@ -625,10 +643,12 @@ private[execution] final class OffsetWindowFunctionFrame(
if (inputIndex >= 0 && inputIndex < input.size) {
val r = input.next()
join(r, current)
projection(join)
} else {
join(emptyRow, current)
// Use default values since the offset row does not exist.
fillDefaultValue(join)
}
projection(join)
inputIndex += 1
}
}

View file

@ -15,12 +15,10 @@
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution
package org.apache.spark.sql.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
case class WindowData(month: Int, area: String, product: Int)
@ -28,8 +26,9 @@ case class WindowData(month: Int, area: String, product: Int)
/**
* Test suite for SQL window functions.
*/
class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import spark.implicits._
class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("window function: udaf with aggregate expression") {
val data = Seq(
@ -357,14 +356,59 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi
}
test("SPARK-7595: Window will cause resolve failed with self join") {
sql("SELECT * FROM src") // Force loading of src table.
checkAnswer(sql(
"""
|with
| v1 as (select key, count(value) over (partition by key) cnt_val from src),
| v0 as (select 0 as key, 1 as value),
| v1 as (select key, count(value) over (partition by key) cnt_val from v0),
| v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key)
| select * from v2 order by key limit 1
""".stripMargin), Row(0, 3))
| select key, cnt_val from v2 order by key limit 1
""".stripMargin), Row(0, 1))
}
test("SPARK-16633: lead/lag should return the default value if the offset row does not exist") {
checkAnswer(sql(
"""
|SELECT
| lag(123, 100, 321) OVER (ORDER BY id) as lag,
| lead(123, 100, 321) OVER (ORDER BY id) as lead
|FROM (SELECT 1 as id) tmp
""".stripMargin),
Row(321, 321))
checkAnswer(sql(
"""
|SELECT
| lag(123, 100, a) OVER (ORDER BY id) as lag,
| lead(123, 100, a) OVER (ORDER BY id) as lead
|FROM (SELECT 1 as id, 2 as a) tmp
""".stripMargin),
Row(2, 2))
}
test("lead/lag should respect null values") {
checkAnswer(sql(
"""
|SELECT
| b,
| lag(a, 1, 321) OVER (ORDER BY b) as lag,
| lead(a, 1, 321) OVER (ORDER BY b) as lead
|FROM (SELECT cast(null as int) as a, 1 as b
| UNION ALL
| select cast(null as int) as id, 2 as b) tmp
""".stripMargin),
Row(1, 321, null) :: Row(2, null, 321) :: Nil)
checkAnswer(sql(
"""
|SELECT
| b,
| lag(a, 1, c) OVER (ORDER BY b) as lag,
| lead(a, 1, c) OVER (ORDER BY b) as lead
|FROM (SELECT cast(null as int) as a, 1 as b, 3 as c
| UNION ALL
| select cast(null as int) as id, 2 as b, 4 as c) tmp
""".stripMargin),
Row(1, 3, null) :: Row(2, null, 4) :: Nil)
}
}