[SPARK-22361][SQL][TEST] Add unit test for Window Frames

## What changes were proposed in this pull request?

There are already quite a few integration tests that use window frames, but unit test coverage is not ideal.

In this PR the existing tests are reorganized and extended, and additional cases are added where gaps were found.
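
For context, a window frame limits which rows of a partition an aggregate sees for each input row: a ROWS frame counts physical rows around the current row, while a RANGE frame selects rows whose ordering value lies within a given distance of the current row's value. The sketch below is not part of this patch; the object name, toy data, and local session are illustrative assumptions. It shows the two frame styles exercised by the new tests via the DataFrame API:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Illustrative only: a minimal, self-contained window-frame example.
object WindowFrameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("window-frame-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "a"), (4, "a"), (3, "b")).toDF("key", "value")

    // ROWS frame: the current row plus one physical row before and one after it.
    val rowsWindow = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1)
    // RANGE frame: all rows whose "key" lies within [current key - 1, current key + 1].
    val rangeWindow = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1)

    df.select(
      $"key",
      sum($"key").over(rowsWindow).as("rows_sum"),
      sum($"key").over(rangeWindow).as("range_sum"))
      .show()
    // For key = 4 in partition "a": rows_sum = 2 + 4 = 6, range_sum = 4 (no other key within distance 1).

    spark.stop()
  }
}
```

The SQL equivalent is the frame clause parsed in ExpressionParserSuite, e.g. `sum(key) OVER (PARTITION BY value ORDER BY key ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)`.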

## How was this patch tested?

Automated: the Jenkins builds pass.

Author: Gabor Somogyi <gabor.g.somogyi@gmail.com>

Closes #20019 from gaborgsomogyi/SPARK-22361.
Gabor Somogyi authored on 2018-01-17 10:03:25 +08:00; committed by gatorsmile.
parent 0c2ba427bc
commit a9b845ebb5
3 changed files with 455 additions and 252 deletions

ExpressionParserSuite.scala

@@ -249,8 +249,8 @@ class ExpressionParserSuite extends PlanTest {
assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b)))
assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b)))
assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b)))
assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc)))
assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc)))
assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc)))
assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc)))
@@ -263,21 +263,62 @@ class ExpressionParserSuite extends PlanTest {
"sum(product + 1) over (partition by ((product / 2) + 1) order by 2)",
WindowExpression('sum.function('product + 1),
WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
}
test("range/rows window function expressions") {
val func = 'foo.function(star())
def windowed(
partitioning: Seq[Expression] = Seq.empty,
ordering: Seq[SortOrder] = Seq.empty,
frame: WindowFrame = UnspecifiedFrame): Expression = {
WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame))
}
// Range/Row
val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame))
val boundaries = Seq(
// No between combinations
("unbounded preceding", UnboundedPreceding, CurrentRow),
("2147483648 preceding", -Literal(2147483648L), CurrentRow),
("10 preceding", -Literal(10), CurrentRow),
("3 + 1 preceding", -Add(Literal(3), Literal(1)), CurrentRow),
("0 preceding", -Literal(0), CurrentRow),
("current row", CurrentRow, CurrentRow),
("0 following", Literal(0), CurrentRow),
("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow),
("10 following", Literal(10), CurrentRow),
("2147483649 following", Literal(2147483649L), CurrentRow),
("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis
// Between combinations
("between unbounded preceding and 5 following",
UnboundedPreceding, Literal(5)),
("between unbounded preceding and 3 + 1 following",
UnboundedPreceding, Add(Literal(3), Literal(1))),
("between unbounded preceding and 2147483649 following",
UnboundedPreceding, Literal(2147483649L)),
("between unbounded preceding and current row", UnboundedPreceding, CurrentRow),
("between 2147483648 preceding and current row", -Literal(2147483648L), CurrentRow),
("between 10 preceding and current row", -Literal(10), CurrentRow),
("between 3 + 1 preceding and current row", -Add(Literal(3), Literal(1)), CurrentRow),
("between 0 preceding and current row", -Literal(0), CurrentRow),
("between current row and current row", CurrentRow, CurrentRow),
("between current row and 0 following", CurrentRow, Literal(0)),
("between current row and 5 following", CurrentRow, Literal(5)),
("between current row and 3 + 1 following", CurrentRow, Add(Literal(3), Literal(1))),
("between current row and 2147483649 following", CurrentRow, Literal(2147483649L)),
("between current row and unbounded following", CurrentRow, UnboundedFollowing),
("between 2147483648 preceding and unbounded following",
-Literal(2147483648L), UnboundedFollowing),
("between 10 preceding and unbounded following",
-Literal(10), UnboundedFollowing),
("between 3 + 1 preceding and unbounded following",
-Add(Literal(3), Literal(1)), UnboundedFollowing),
("between 0 preceding and unbounded following", -Literal(0), UnboundedFollowing),
// Between partial and full range
("between 10 preceding and 5 following", -Literal(10), Literal(5)),
("between unbounded preceding and unbounded following",
UnboundedPreceding, UnboundedFollowing)
)
frameTypes.foreach {
case (frameTypeSql, frameType) =>

DataFrameWindowFramesSuite.scala (new file)

@@ -0,0 +1,405 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.CalendarInterval
/**
* Window frame testing for DataFrame API.
*/
class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("lead/lag with empty data frame") {
val df = Seq.empty[(Int, String)].toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value")
checkAnswer(
df.select(
lead("value", 1).over(window),
lag("value", 1).over(window)),
Nil)
}
test("lead/lag with positive offset") {
val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value")
checkAnswer(
df.select(
$"key",
lead("value", 1).over(window),
lag("value", 1).over(window)),
Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil)
}
test("reverse lead/lag with positive offset") {
val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value".desc)
checkAnswer(
df.select(
$"key",
lead("value", 1).over(window),
lag("value", 1).over(window)),
Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil)
}
test("lead/lag with negative offset") {
val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value")
checkAnswer(
df.select(
$"key",
lead("value", -1).over(window),
lag("value", -1).over(window)),
Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil)
}
test("reverse lead/lag with negative offset") {
val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value".desc)
checkAnswer(
df.select(
$"key",
lead("value", -1).over(window),
lag("value", -1).over(window)),
Row(1, null, "1") :: Row(1, "3", null) :: Row(2, null, "2") :: Row(2, "4", null) :: Nil)
}
test("lead/lag with default value") {
val default = "n/a"
val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4"), (2, "5")).toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value")
checkAnswer(
df.select(
$"key",
lead("value", 2, default).over(window),
lag("value", 2, default).over(window),
lead("value", -2, default).over(window),
lag("value", -2, default).over(window)),
Row(1, default, default, default, default) :: Row(1, default, default, default, default) ::
Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) ::
Row(2, default, default, default, default) :: Nil)
}
test("rows/range between with empty data frame") {
val df = Seq.empty[(String, Int)].toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value")
checkAnswer(
df.select(
'key,
first("value").over(
window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
first("value").over(
window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
Nil)
}
test("rows between should accept int/long values as boundary") {
val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2"))
.toDF("key", "value")
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))),
Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1))
)
val e = intercept[AnalysisException](
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L))))
assert(e.message.contains("Boundary end is not a valid integer: 2147483648"))
}
test("range between should accept at most one ORDER BY expression when unbounded") {
val df = Seq((1, 1)).toDF("key", "value")
val window = Window.orderBy($"key", $"value")
checkAnswer(
df.select(
$"key",
min("key").over(
window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
Seq(Row(1, 1))
)
val e1 = intercept[AnalysisException](
df.select(
min("key").over(window.rangeBetween(Window.unboundedPreceding, 1))))
assert(e1.message.contains("A range window frame with value boundaries cannot be used in a " +
"window specification with multiple order by expressions"))
val e2 = intercept[AnalysisException](
df.select(
min("key").over(window.rangeBetween(-1, Window.unboundedFollowing))))
assert(e2.message.contains("A range window frame with value boundaries cannot be used in a " +
"window specification with multiple order by expressions"))
val e3 = intercept[AnalysisException](
df.select(
min("key").over(window.rangeBetween(-1, 1))))
assert(e3.message.contains("A range window frame with value boundaries cannot be used in a " +
"window specification with multiple order by expressions"))
}
test("range between should accept numeric values only when bounded") {
val df = Seq("non_numeric").toDF("value")
val window = Window.orderBy($"value")
checkAnswer(
df.select(
$"value",
min("value").over(
window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
Row("non_numeric", "non_numeric") :: Nil)
val e1 = intercept[AnalysisException](
df.select(
min("value").over(window.rangeBetween(Window.unboundedPreceding, 1))))
assert(e1.message.contains("The data type of the upper bound 'string' " +
"does not match the expected data type"))
val e2 = intercept[AnalysisException](
df.select(
min("value").over(window.rangeBetween(-1, Window.unboundedFollowing))))
assert(e2.message.contains("The data type of the lower bound 'string' " +
"does not match the expected data type"))
val e3 = intercept[AnalysisException](
df.select(
min("value").over(window.rangeBetween(-1, 1))))
assert(e3.message.contains("The data type of the lower bound 'string' " +
"does not match the expected data type"))
}
test("range between should accept int/long values as boundary") {
val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2"))
.toDF("key", "value")
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))),
Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1))
)
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))),
Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1))
)
def dt(date: String): Date = Date.valueOf(date)
val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"),
(dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2"))
.toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2))
checkAnswer(
df2.select(
$"key",
count("key").over(window)),
Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1),
Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1))
)
}
test("range between should accept double values as boundary") {
val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), (3.3D, "2"), (2.02D, "1"),
(100.001D, "2")).toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(currentRow, lit(2.5D))
checkAnswer(
df.select(
$"key",
count("key").over(window)),
Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1))
)
}
test("range between should accept interval values as boundary") {
def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000)
val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"),
(ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2"))
.toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key")
.rangeBetween(currentRow, lit(CalendarInterval.fromString("interval 23 days 4 hours")))
checkAnswer(
df.select(
$"key",
count("key").over(window)),
Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1),
Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1))
)
}
test("unbounded rows/range between with aggregation") {
val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value")
val window = Window.partitionBy($"key").orderBy($"value")
checkAnswer(
df.select(
'key,
sum("value").over(window.
rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
sum("value").over(window.
rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil)
}
test("unbounded preceding/following rows between with aggregation") {
val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key")
checkAnswer(
df.select(
$"key",
last("key").over(
window.rowsBetween(Window.currentRow, Window.unboundedFollowing)),
last("key").over(
window.rowsBetween(Window.unboundedPreceding, Window.currentRow))),
Row(1, 1, 1) :: Row(2, 3, 2) :: Row(3, 3, 3) :: Row(1, 4, 1) :: Row(2, 4, 2) ::
Row(4, 4, 4) :: Nil)
}
test("reverse unbounded preceding/following rows between with aggregation") {
val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key".desc)
checkAnswer(
df.select(
$"key",
last("key").over(
window.rowsBetween(Window.currentRow, Window.unboundedFollowing)),
last("key").over(
window.rowsBetween(Window.unboundedPreceding, Window.currentRow))),
Row(1, 1, 1) :: Row(3, 2, 3) :: Row(2, 2, 2) :: Row(4, 1, 4) :: Row(2, 1, 2) ::
Row(1, 1, 1) :: Nil)
}
test("unbounded preceding/following range between with aggregation") {
val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value")
val window = Window.partitionBy("value").orderBy("key")
checkAnswer(
df.select(
$"key",
avg("key").over(window.rangeBetween(Window.unboundedPreceding, 1))
.as("avg_key1"),
avg("key").over(window.rangeBetween(Window.currentRow, Window.unboundedFollowing))
.as("avg_key2")),
Row(3, 3.0d, 4.0d) :: Row(5, 4.0d, 5.0d) :: Row(2, 2.0d, 17.0d / 4.0d) ::
Row(4, 11.0d / 3.0d, 5.0d) :: Row(5, 17.0d / 4.0d, 11.0d / 2.0d) ::
Row(6, 17.0d / 4.0d, 6.0d) :: Nil)
}
// This is here to illustrate the fact that reverse order also reverses offsets.
test("reverse preceding/following range between with aggregation") {
val df = Seq(1, 2, 4, 3, 2, 1).toDF("value")
val window = Window.orderBy($"value".desc)
checkAnswer(
df.select(
$"value",
sum($"value").over(window.rangeBetween(Window.unboundedPreceding, 1)),
sum($"value").over(window.rangeBetween(1, Window.unboundedFollowing))),
Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) ::
Row(2, 13, 2) :: Row(1, 13, null) :: Nil)
}
test("sliding rows between with aggregation") {
val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2)
checkAnswer(
df.select(
$"key",
avg("key").over(window)),
Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 3.0d / 2.0d) :: Row(2, 2.0d) ::
Row(2, 2.0d) :: Nil)
}
test("reverse sliding rows between with aggregation") {
val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key".desc).rowsBetween(-1, 2)
checkAnswer(
df.select(
$"key",
avg("key").over(window)),
Row(1, 1.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 4.0d / 3.0d) :: Row(2, 2.0d) ::
Row(2, 2.0d) :: Nil)
}
test("sliding range between with aggregation") {
val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value")
val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1)
checkAnswer(
df.select(
$"key",
avg("key").over(window)),
Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 7.0d / 4.0d) :: Row(3, 5.0d / 2.0d) ::
Row(2, 2.0d) :: Row(2, 2.0d) :: Nil)
}
test("reverse sliding range between with aggregation") {
val df = Seq(
(1, "Thin", "Cell Phone", 6000),
(2, "Normal", "Tablet", 1500),
(3, "Mini", "Tablet", 5500),
(4, "Ultra thin", "Cell Phone", 5500),
(5, "Very thin", "Cell Phone", 6000),
(6, "Big", "Tablet", 2500),
(7, "Bendable", "Cell Phone", 3000),
(8, "Foldable", "Cell Phone", 3000),
(9, "Pro", "Tablet", 4500),
(10, "Pro2", "Tablet", 6500)).
toDF("id", "product", "category", "revenue")
val window = Window.partitionBy($"category").orderBy($"revenue".desc).
rangeBetween(-2000L, 1000L)
checkAnswer(
df.select(
$"id",
avg($"revenue").over(window).cast("int")),
Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) ::
Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) ::
Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) ::
Row(10, 6000) :: Nil)
}
}

DataFrameWindowFunctionsSuite.scala

@@ -55,56 +55,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
}
test("Window.rowsBetween") {
val df = Seq(("one", 1), ("two", 2)).toDF("key", "value")
// Running (cumulative) sum
checkAnswer(
df.select('key, sum("value").over(
Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))),
Row("one", 1) :: Row("two", 3) :: Nil
)
}
test("lead") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))),
Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil)
}
test("lag") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))),
Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil)
}
test("lead with default value") {
val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
(2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))),
Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a")))
}
test("lag with default value") {
val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
(2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))),
Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2")))
}
test("rank functions in unspecific window") { test("rank functions in unspecific window") {
val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table") df.createOrReplaceTempView("window_table")
@ -136,199 +86,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("requires window to be ordered")) assert(e.message.contains("requires window to be ordered"))
} }
test("aggregation and rows between") {
val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))),
Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d)))
}
test("aggregation and range between") {
val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))),
Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d),
Row(2.0d), Row(2.0d)))
}
test("row between should accept integer values as boundary") {
val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"),
(3L, "2"), (2L, "1"), (2147483650L, "2"))
.toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))),
Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1))
)
val e = intercept[AnalysisException](
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L))))
assert(e.message.contains("Boundary end is not a valid integer: 2147483648"))
}
test("range between should accept int/long values as boundary") {
val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"),
(3L, "2"), (2L, "1"), (2147483650L, "2"))
.toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))),
Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1))
)
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))),
Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1))
)
def dt(date: String): Date = Date.valueOf(date)
val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"),
(dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2"))
.toDF("key", "value")
checkAnswer(
df2.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)))),
Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1),
Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1))
)
}
test("range between should accept double values as boundary") {
val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"),
(3.3D, "2"), (2.02D, "1"), (100.001D, "2"))
.toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key")
.rangeBetween(currentRow, lit(2.5D)))),
Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1))
)
}
test("range between should accept interval values as boundary") {
def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000)
val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"),
(ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2"))
.toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
$"key",
count("key").over(
Window.partitionBy($"value").orderBy($"key")
.rangeBetween(currentRow,
lit(CalendarInterval.fromString("interval 23 days 4 hours"))))),
Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1),
Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1))
)
}
test("aggregation and rows between with unbounded") {
val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
$"key",
last("key").over(
Window.partitionBy($"value").orderBy($"key")
.rowsBetween(Window.currentRow, Window.unboundedFollowing)),
last("key").over(
Window.partitionBy($"value").orderBy($"key")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)),
last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))),
Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4),
Row(4, 4, 4, 4)))
}
test("aggregation and range between with unbounded") {
val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value")
df.createOrReplaceTempView("window_table")
checkAnswer(
df.select(
$"key",
last("value").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1))
.equalTo("2")
.as("last_v"),
avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1))
.as("avg_key1"),
avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue))
.as("avg_key2"),
avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0))
.as("avg_key3")
),
Seq(Row(3, null, 3.0d, 4.0d, 3.0d),
Row(5, false, 4.0d, 5.0d, 5.0d),
Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d),
Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d),
Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d),
Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d)))
}
test("reverse sliding range frame") {
val df = Seq(
(1, "Thin", "Cell Phone", 6000),
(2, "Normal", "Tablet", 1500),
(3, "Mini", "Tablet", 5500),
(4, "Ultra thin", "Cell Phone", 5500),
(5, "Very thin", "Cell Phone", 6000),
(6, "Big", "Tablet", 2500),
(7, "Bendable", "Cell Phone", 3000),
(8, "Foldable", "Cell Phone", 3000),
(9, "Pro", "Tablet", 4500),
(10, "Pro2", "Tablet", 6500)).
toDF("id", "product", "category", "revenue")
val window = Window.
partitionBy($"category").
orderBy($"revenue".desc).
rangeBetween(-2000L, 1000L)
checkAnswer(
df.select(
$"id",
avg($"revenue").over(window).cast("int")),
Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) ::
Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) ::
Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) ::
Row(10, 6000) :: Nil)
}
// This is here to illustrate the fact that reverse order also reverses offsets.
test("reverse unbounded range frame") {
val df = Seq(1, 2, 4, 3, 2, 1).
map(Tuple1.apply).
toDF("value")
val window = Window.orderBy($"value".desc)
checkAnswer(
df.select(
$"value",
sum($"value").over(window.rangeBetween(Long.MinValue, 1)),
sum($"value").over(window.rangeBetween(1, Long.MaxValue))),
Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) ::
Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil)
}
test("statistical functions") { test("statistical functions") {
val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
toDF("key", "value") toDF("key", "value")