[SPARK-32338][SQL][PYSPARK][FOLLOW-UP][TEST] Add more tests for slice function

### What changes were proposed in this pull request?

This PR is a follow-up of #29138 and #29195 to add more tests for the `slice` function.

### Why are the changes needed?

The original PRs are missing tests with column-based arguments instead of literals.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added new tests, and ran the existing tests.

Closes #31159 from ueshin/issues/SPARK-32338/slice_tests.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
Takuya UESHIN 2021-01-13 09:56:38 +09:00 committed by HyukjinKwon
parent 65222b7051
commit ad8e40e2ab
2 changed files with 18 additions and 1 deletions

View file

@ -350,7 +350,7 @@ class FunctionsTests(ReusedSQLTestCase):
self.assertEqual(result[0], '')
def test_slice(self):
from pyspark.sql.functions import slice, lit
from pyspark.sql.functions import lit, size, slice
df = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
@ -359,6 +359,15 @@ class FunctionsTests(ReusedSQLTestCase):
df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect(),
)
self.assertEqual(
df.select(slice(df.x, size(df.x) - 1, lit(1)).alias("sliced")).collect(),
[Row(sliced=[2]), Row(sliced=[4])]
)
self.assertEqual(
df.select(slice(df.x, lit(1), size(df.x) - 1).alias("sliced")).collect(),
[Row(sliced=[1, 2]), Row(sliced=[4])]
)
def test_array_repeat(self):
from pyspark.sql.functions import array_repeat, lit

View file

@ -943,6 +943,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative)
checkAnswer(df.select(slice(df("x"), lit(-1), lit(1))), answerNegative)
checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative)
val answerStartExpr = Seq(Row(Seq(2)), Row(Seq(4)))
checkAnswer(df.select(slice(df("x"), size($"x") - 1, lit(1))), answerStartExpr)
checkAnswer(df.selectExpr("slice(x, size(x) - 1, 1)"), answerStartExpr)
val answerLengthExpr = Seq(Row(Seq(1, 2)), Row(Seq(4)))
checkAnswer(df.select(slice(df("x"), lit(1), size($"x") - 1)), answerLengthExpr)
checkAnswer(df.selectExpr("slice(x, 1, size(x) - 1)"), answerLengthExpr)
}
test("array_join function") {