diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d1d2c59cae..61162ccdba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1787,7 +1787,8 @@ class Analyzer(
             s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
           if wf.frame != UnspecifiedFrame =>
           WindowExpression(wf, s.copy(frameSpecification = wf.frame))
-        case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) =>
+        case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
+          if e.resolved =>
           val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true)
           we.copy(windowSpec = s.copy(frameSpecification = frame))
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index e35192ca2d..6806591f68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -321,8 +321,7 @@ abstract class OffsetWindowFunction
   val input: Expression
 
   /**
-   * Default result value for the function when the input expression returns NULL. The default will
-   * evaluated against the current row instead of the offset row.
+   * Default result value for the function when the 'offset'th row does not exist.
    */
   val default: Expression
 
@@ -348,7 +347,7 @@ abstract class OffsetWindowFunction
    */
   override def foldable: Boolean = false
 
-  override def nullable: Boolean = default == null || default.nullable
+  override def nullable: Boolean = default == null || default.nullable || input.nullable
 
   override lazy val frame = {
     // This will be triggered by the Analyzer.
@@ -373,20 +372,22 @@ abstract class OffsetWindowFunction
 }
 
 /**
- * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window.
- * Offsets start at 0, which is the current row. The offset must be constant integer value. The
- * default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger
- * than the window, the default expression is evaluated.
- *
- * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ * The Lead function returns the value of 'x' at the 'offset'th row after the current row in
+ * the window. Offsets start at 0, which is the current row. The offset must be a constant
+ * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row,
+ * null is returned. If there is no such offset row, the default expression is evaluated.
  *
  * @param input expression to evaluate 'offset' rows after the current row.
  * @param offset rows to jump ahead in the partition.
- * @param default to use when the input value is null or when the offset is larger than the window.
+ * @param default to use when the offset is larger than the window. The default value is null.
  */
 @ExpressionDescription(usage =
-  """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows
-     after the current row in the window""")
+  """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at the 'offset'th row
+     after the current row in the window.
+     The default value of 'offset' is 1 and the default value of 'default' is null.
+     If the value of 'x' at the 'offset'th row is null, null is returned.
+     If there is no such offset row (e.g. when the offset is 1, the last row of the window
+     does not have any subsequent row), 'default' is returned.""")
 case class Lead(input: Expression, offset: Expression, default: Expression)
     extends OffsetWindowFunction {
 
@@ -400,20 +401,22 @@ case class Lead(input: Expression, offset: Expression, default: Expression)
 }
 
 /**
- * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window.
- * Offsets start at 0, which is the current row. The offset must be constant integer value. The
- * default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller
- * than the window, the default expression is evaluated.
- *
- * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ * The Lag function returns the value of 'x' at the 'offset'th row before the current row in
+ * the window. Offsets start at 0, which is the current row. The offset must be a constant
+ * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row,
+ * null is returned. If there is no such offset row, the default expression is evaluated.
  *
  * @param input expression to evaluate 'offset' rows before the current row.
  * @param offset rows to jump back in the partition.
- * @param default to use when the input value is null or when the offset is smaller than the window.
+ * @param default to use when the offset row does not exist.
  */
 @ExpressionDescription(usage =
-  """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows
-     before the current row in the window""")
+  """_FUNC_(input, offset, default) - LAG returns the value of 'x' at the 'offset'th row
+     before the current row in the window.
+     The default value of 'offset' is 1 and the default value of 'default' is null.
+     If the value of 'x' at the 'offset'th row is null, null is returned.
+     If there is no such offset row (e.g. when the offset is 1, the first row of the window
+     does not have any previous row), 'default' is returned.""")
 case class Lag(input: Expression, offset: Expression, default: Expression)
     extends OffsetWindowFunction {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 93f007f5b3..7149603018 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -582,25 +582,43 @@ private[execution] final class OffsetWindowFunctionFrame(
   /** Row used to combine the offset and the current row. */
   private[this] val join = new JoinedRow
 
-  /** Create the projection. */
+  /**
+   * Create the projection used when the offset row exists.
+   * Please note that this projection always respects null input values (like PostgreSQL).
+   */
   private[this] val projection = {
+    // Collect the expressions and bind them.
+    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
+      case e: OffsetWindowFunction =>
+        val input = BindReferences.bindReference(e.input, inputAttrs)
+        input
+      case e =>
+        BindReferences.bindReference(e, inputAttrs)
+    }
+
+    // Create the projection.
+    newMutableProjection(boundExpressions, Nil).target(target)
+  }
+
+  /** Create the projection used when the offset row DOES NOT exist. */
+  private[this] val fillDefaultValue = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
     val numInputAttributes = inputAttrs.size
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
       case e: OffsetWindowFunction =>
-        val input = BindReferences.bindReference(e.input, inputAttrs)
         if (e.default == null || e.default.foldable && e.default.eval() == null) {
-          // Without default value.
-          input
+          // The default value is null.
+          Literal.create(null, e.dataType)
         } else {
-          // With default value.
+          // The default value is an expression.
           val default = BindReferences.bindReference(e.default, inputAttrs).transform {
             // Shift the input reference to its default version.
             case BoundReference(o, dataType, nullable) =>
               BoundReference(o + numInputAttributes, dataType, nullable)
           }
-          org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil)
+          default
         }
       case e =>
         BindReferences.bindReference(e, inputAttrs)
@@ -625,10 +643,12 @@ private[execution] final class OffsetWindowFunctionFrame(
     if (inputIndex >= 0 && inputIndex < input.size) {
       val r = input.next()
       join(r, current)
+      projection(join)
     } else {
       join(emptyRow, current)
+      // Use default values since the offset row does not exist.
+      fillDefaultValue(join)
     }
-    projection(join)
     inputIndex += 1
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
similarity index 85%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index 77e97dff8c..d3cfa953a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -15,12 +15,10 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
-
+import org.apache.spark.sql.test.SharedSQLContext
 
 case class WindowData(month: Int, area: String, product: Int)
 
@@ -28,8 +26,9 @@ case class WindowData(month: Int, area: String, product: Int)
 /**
  * Test suite for SQL window functions.
  */
-class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
-  import spark.implicits._
+class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
 
   test("window function: udaf with aggregate expression") {
     val data = Seq(
@@ -357,14 +356,59 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi
   }
 
   test("SPARK-7595: Window will cause resolve failed with self join") {
-    sql("SELECT * FROM src") // Force loading of src table.
-
     checkAnswer(sql(
       """
        |with
-       | v1 as (select key, count(value) over (partition by key) cnt_val from src),
+       | v0 as (select 0 as key, 1 as value),
+       | v1 as (select key, count(value) over (partition by key) cnt_val from v0),
        | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key)
-       | select * from v2 order by key limit 1
-      """.stripMargin), Row(0, 3))
+       | select key, cnt_val from v2 order by key limit 1
+      """.stripMargin), Row(0, 1))
+  }
+
+  test("SPARK-16633: lead/lag should return the default value if the offset row does not exist") {
+    checkAnswer(sql(
+      """
+        |SELECT
+        | lag(123, 100, 321) OVER (ORDER BY id) as lag,
+        | lead(123, 100, 321) OVER (ORDER BY id) as lead
+        |FROM (SELECT 1 as id) tmp
+      """.stripMargin),
+      Row(321, 321))
+
+    checkAnswer(sql(
+      """
+        |SELECT
+        | lag(123, 100, a) OVER (ORDER BY id) as lag,
+        | lead(123, 100, a) OVER (ORDER BY id) as lead
+        |FROM (SELECT 1 as id, 2 as a) tmp
+      """.stripMargin),
+      Row(2, 2))
+  }
+
+  test("lead/lag should respect null values") {
+    checkAnswer(sql(
+      """
+        |SELECT
+        | b,
+        | lag(a, 1, 321) OVER (ORDER BY b) as lag,
+        | lead(a, 1, 321) OVER (ORDER BY b) as lead
+        |FROM (SELECT cast(null as int) as a, 1 as b
+        | UNION ALL
+        | select cast(null as int) as id, 2 as b) tmp
+      """.stripMargin),
+      Row(1, 321, null) :: Row(2, null, 321) :: Nil)
+
+    checkAnswer(sql(
+      """
+        |SELECT
+        | b,
+        | lag(a, 1, c) OVER (ORDER BY b) as lag,
+        | lead(a, 1, c) OVER (ORDER BY b) as lead
+        |FROM (SELECT cast(null as int) as a, 1 as b, 3 as c
+        | UNION ALL
+        | select cast(null as int) as id, 2 as b, 4 as c) tmp
+      """.stripMargin),
+      Row(1, 3, null) :: Row(2, null, 4) :: Nil)
   }
 }
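
Illustrative usage sketch (not part of the patch): a minimal standalone Scala program that replays, through the public SparkSession API, the behavior the new SQLWindowFunctionSuite tests pin down — lead/lag fall back to 'default' only when the offset row does not exist, and pass a null stored in an existing offset row straight through. The object name and the local master setting are assumptions made for the sake of a runnable example.

import org.apache.spark.sql.SparkSession

// Hypothetical demo object; names here are illustrative only.
object LeadLagDefaultDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("lead-lag-default-demo")
      .master("local[1]") // assumption: run on a local master
      .getOrCreate()

    // Same shape as the "lead/lag should respect null values" test:
    // 'a' is null in both rows, so lead/lag over an existing row returns null,
    // while a missing first/last offset row yields the default (321).
    spark.sql(
      """
        |SELECT
        |  b,
        |  lag(a, 1, 321)  OVER (ORDER BY b) AS lag,
        |  lead(a, 1, 321) OVER (ORDER BY b) AS lead
        |FROM (SELECT cast(null AS int) AS a, 1 AS b
        |      UNION ALL
        |      SELECT cast(null AS int) AS a, 2 AS b) tmp
      """.stripMargin).show()
    // Expected, per the suite: (1, 321, null) and (2, null, 321)

    spark.stop()
  }
}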