#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
from itertools import chain
import re

from pyspark.sql import Row, Window
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, lit
from pyspark.testing.sqlutils import ReusedSQLTestCase


class FunctionsTests(ReusedSQLTestCase):

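    # Checks explode/explode_outer/posexplode_outer over array and map columns,
    # including empty and null collections.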
    def test_explode(self):
        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
        d = [
            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
            Row(a=1, intlist=[], mapfield={}),
            Row(a=1, intlist=None, mapfield=None),
        ]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
        self.assertEqual(result[0][0], 1)
        self.assertEqual(result[1][0], 2)
        self.assertEqual(result[2][0], 3)

        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
        self.assertEqual(result[0][0], "a")
        self.assertEqual(result[0][1], "b")

        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])

        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])

        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
        self.assertEqual(result, [1, 2, 3, None, None])

        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])

    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
        df.count()
        df.collect()
        df.schema

        # cache and checkpoint
        self.assertFalse(df.is_cached)
        df.persist()
        df.unpersist(True)
        df.cache()
        self.assertTrue(df.is_cached)
        self.assertEqual(2, df.count())

        with self.tempView("temp"):
            df.createOrReplaceTempView("temp")
            df = self.spark.sql("select foo from temp")
            df.count()
            df.collect()

    def test_corr(self):
        import math
        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
        corr = df.stat.corr(u"a", "b")
        self.assertTrue(abs(corr - 0.95734012) < 1e-6)

    def test_sampleby(self):
        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(100)]).toDF()
        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
        self.assertTrue(sampled.count() == 35)

    def test_cov(self):
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        cov = df.stat.cov(u"a", "b")
        self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

    def test_crosstab(self):
        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
        ct = df.stat.crosstab(u"a", "b").collect()
        ct = sorted(ct, key=lambda x: x[0])
        for i, row in enumerate(ct):
            self.assertEqual(row[0], str(i))
            # assertTrue(row[1], 1) only checked truthiness; the intent is an
            # equality check against the expected count of 1.
            self.assertEqual(row[1], 1)
            self.assertEqual(row[2], 1)

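    # Compares Spark SQL math functions against Python's math module on the same
    # inputs; assert_close treats values within 1e-6 of each other as equal.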
    def test_math_functions(self):
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        from pyspark.sql import functions
        import math

        def get_values(l):
            return [j[0] for j in l]

        def assert_close(a, b):
            c = get_values(b)
            diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
            # Previously this helper only returned the comparison result, so
            # mismatches were silently ignored; assert it instead.
            self.assertTrue(sum(diff) == len(a))
        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos(df.a)).collect())
        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos("a")).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df.a)).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df['a'])).collect())
        assert_close([math.pow(i, 2 * i) for i in range(10)],
                     df.select(functions.pow(df.a, df.b)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2.0)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot(df.a, df.b)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot("a", u"b")).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot("a", 2)).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot(df.a, 2)).collect())

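    # rand() values must fall in [0.0, 1.0], and an explicit seed (including 0)
    # must make rand()/randn() reproducible (SPARK-9691).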
    def test_rand_functions(self):
        df = self.df
        from pyspark.sql import functions
        rnd = df.select('key', functions.rand()).collect()
        for row in rnd:
            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
        rndn = df.select('key', functions.randn(5)).collect()
        for row in rndn:
            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]

        # If the specified seed is 0, we should use it.
        # https://issues.apache.org/jira/browse/SPARK-9691
        rnd1 = df.select('key', functions.rand(0)).collect()
        rnd2 = df.select('key', functions.rand(0)).collect()
        self.assertEqual(sorted(rnd1), sorted(rnd2))

        rndn1 = df.select('key', functions.randn(0)).collect()
        rndn2 = df.select('key', functions.randn(0)).collect()
        self.assertEqual(sorted(rndn1), sorted(rndn2))

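    # Every function registered in _string_functions must accept a column name
    # string and a Column object interchangeably (SPARK-26979).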
    def test_string_functions(self):
        from pyspark.sql import functions
        from pyspark.sql.functions import col, lit, _string_functions
        df = self.spark.createDataFrame([['nick']], schema=['name'])
        self.assertRaisesRegexp(
            TypeError,
            "must be the same type",
            lambda: df.select(col('name').substr(0, lit(1))))

        for name in _string_functions.keys():
            self.assertEqual(
                df.select(getattr(functions, name)("name")).first()[0],
                df.select(getattr(functions, name)(col("name"))).first()[0])

    def test_array_contains_function(self):
        from pyspark.sql.functions import array_contains

        df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
        actual = df.select(array_contains(df.data, "1").alias('b')).collect()
        self.assertEqual([Row(b=True), Row(b=False)], actual)

    def test_between_function(self):
        df = self.sc.parallelize([
            Row(a=1, b=2, c=3),
            Row(a=2, b=1, c=3),
            Row(a=4, b=1, c=4)]).toDF()
        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
                         df.filter(df.a.between(df.b, df.c)).collect())

    def test_dayofweek(self):
        from pyspark.sql.functions import dayofweek
        dt = datetime.datetime(2017, 11, 6)
        df = self.spark.createDataFrame([Row(date=dt)])
        row = df.select(dayofweek(df.date)).first()
        self.assertEqual(row[0], 2)

    def test_expr(self):
        from pyspark.sql import functions
        row = Row(a="length string", b=75)
        df = self.spark.createDataFrame([row])
        result = df.select(functions.expr("length(a)")).collect()[0].asDict()
        self.assertEqual(13, result["length(a)"])

    # add test for SPARK-10577 (test broadcast join hint)
    def test_functions_broadcast(self):
        from pyspark.sql.functions import broadcast

        df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
        df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))

        # equijoin - should be converted into broadcast join
        plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
        self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))

        # no join key -- should not be a broadcast join
        plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
        self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))

        # planner should not crash without a join
        broadcast(df1)._jdf.queryExecution().executedPlan()

    def test_first_last_ignorenulls(self):
        from pyspark.sql import functions
        df = self.spark.range(0, 100)
        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
        df3 = df2.select(functions.first(df2.id, False).alias('a'),
                         functions.first(df2.id, True).alias('b'),
                         functions.last(df2.id, False).alias('c'),
                         functions.last(df2.id, True).alias('d'))
        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())

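    # approxQuantile should accept a single column name as well as a list or tuple
    # of names, and reject non-string column arguments with ValueError.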
    def test_approxQuantile(self):
        df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
        for f in ["a", u"a"]:
            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
            self.assertTrue(isinstance(aq, list))
            self.assertEqual(len(aq), 3)
            self.assertTrue(all(isinstance(q, float) for q in aq))
        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqs, list))
        self.assertEqual(len(aqs), 2)
        self.assertTrue(isinstance(aqs[0], list))
        self.assertEqual(len(aqs[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
        self.assertTrue(isinstance(aqs[1], list))
        self.assertEqual(len(aqs[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqt, list))
        self.assertEqual(len(aqt), 2)
        self.assertTrue(isinstance(aqt[0], list))
        self.assertEqual(len(aqt[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
        self.assertTrue(isinstance(aqt[1], list))
        self.assertEqual(len(aqt[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))

    def test_sort_with_nulls_order(self):
        from pyspark.sql import functions

        df = self.spark.createDataFrame(
            [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
        self.assertEquals(
            df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
        self.assertEquals(
            df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
            [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
        self.assertEquals(
            df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
        self.assertEquals(
            df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
            [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])

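    # input_file_name() state must not leak from a previous file-based job into a
    # non-file query (SPARK-24605).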
    def test_input_file_name_reset_for_rdd(self):
        rdd = self.sc.textFile('python/test_support/hello/hello.txt').map(lambda x: {'data': x})
        df = self.spark.createDataFrame(rdd, "data STRING")
        df.select(input_file_name().alias('file')).collect()

        non_file_df = self.spark.range(100).select(input_file_name())

        results = non_file_df.collect()
        self.assertTrue(len(results) == 100)

        # [SPARK-24605]: if everything was properly reset after the last job, this should return
        # empty string rather than the file read in the last job.
        for result in results:
            self.assertEqual(result[0], '')

    def test_slice(self):
        from pyspark.sql.functions import slice, lit

        df = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])

        self.assertEquals(
            df.select(slice(df.x, 2, 2).alias("sliced")).collect(),
            df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect(),
        )

    def test_array_repeat(self):
        from pyspark.sql.functions import array_repeat, lit

        df = self.spark.range(1)

        self.assertEquals(
            df.select(array_repeat("id", 3)).toDF("val").collect(),
            df.select(array_repeat("id", lit(3))).toDF("val").collect(),
        )

    def test_input_file_name_udf(self):
        df = self.spark.read.text('python/test_support/hello/hello.txt')
        df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file'))
        file_name = df.collect()[0].file
        self.assertTrue("python/test_support/hello/hello.txt" in file_name)

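    # overlay() should accept column names, Columns, and literal ints for the
    # position/length arguments; the generated expression strings are compared below.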
    def test_overlay(self):
        from pyspark.sql.functions import col, lit, overlay
        from itertools import chain
        import re

        actual = list(chain.from_iterable([
            re.findall("(overlay\\(.*\\))", str(x)) for x in [
                overlay(col("foo"), col("bar"), 1),
                overlay("x", "y", 3),
                overlay(col("x"), col("y"), 1, 3),
                overlay("x", "y", 2, 5),
                overlay("x", "y", lit(11)),
                overlay("x", "y", lit(2), lit(5)),
            ]
        ]))

        expected = [
            "overlay(foo, bar, 1, -1)",
            "overlay(x, y, 3, -1)",
            "overlay(x, y, 1, 3)",
            "overlay(x, y, 2, 5)",
            "overlay(x, y, 11, -1)",
            "overlay(x, y, 2, 5)",
        ]

        self.assertListEqual(actual, expected)

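    # percentile_approx() accepts a single percentage or a list/tuple of percentages,
    # with an optional accuracy that defaults to 10000 (SPARK-30569).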
    def test_percentile_approx(self):
        actual = list(chain.from_iterable([
            re.findall("(percentile_approx\\(.*\\))", str(x)) for x in [
                percentile_approx(col("foo"), lit(0.5)),
                percentile_approx(col("bar"), 0.25, 42),
                percentile_approx(col("bar"), [0.25, 0.5, 0.75]),
                percentile_approx(col("foo"), (0.05, 0.95), 100),
                percentile_approx("foo", 0.5),
                percentile_approx("bar", [0.1, 0.9], lit(10)),
            ]
        ]))

        expected = [
            "percentile_approx(foo, 0.5, 10000)",
            "percentile_approx(bar, 0.25, 42)",
            "percentile_approx(bar, array(0.25, 0.5, 0.75), 10000)",
            "percentile_approx(foo, array(0.05, 0.95), 100)",
            "percentile_approx(foo, 0.5, 10000)",
            "percentile_approx(bar, array(0.1, 0.9), 10)"
        ]

        self.assertListEqual(actual, expected)

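    # nth_value() over an ordered, partitioned window, with and without ignoreNulls.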
    def test_nth_value(self):
        from pyspark.sql import Window
        from pyspark.sql.functions import nth_value

        df = self.spark.createDataFrame([
            ("a", 0, None),
            ("a", 1, "x"),
            ("a", 2, "y"),
            ("a", 3, "z"),
            ("a", 4, None),
            ("b", 1, None),
            ("b", 2, None)], schema=("key", "order", "value"))
        w = Window.partitionBy("key").orderBy("order")

        rs = df.select(
            df.key,
            df.order,
            nth_value("value", 2).over(w),
            nth_value("value", 2, False).over(w),
            nth_value("value", 2, True).over(w)).collect()

        expected = [
            ("a", 0, None, None, None),
            ("a", 1, "x", "x", None),
            ("a", 2, "x", "x", "y"),
            ("a", 3, "x", "x", "y"),
            ("a", 4, "x", "x", "y"),
            ("b", 1, None, None, None),
            ("b", 2, None, None, None)
        ]

        for r, ex in zip(sorted(rs), sorted(expected)):
            self.assertEqual(tuple(r), ex[:len(r)])

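    # transform() only accepts Python functions of one to three positional arguments
    # that return a Column; anything else should raise ValueError (SPARK-30681).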
    def test_higher_order_function_failures(self):
        from pyspark.sql.functions import col, transform

        # Should fail with varargs
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda *x: lit(1))

        # Should fail with kwargs
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda **x: lit(1))

        # Should fail with nullary function
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda: lit(1))

        # Should fail with quaternary function
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda x1, x2, x3, x4: lit(1))

        # Should fail if function doesn't return Column
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda x: 1)

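    # Window aggregate and ranking functions (max, min, count, row_number, rank,
    # dense_rank, ntile) over a partitioned and ordered window.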
    def test_window_functions(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.partitionBy("value").orderBy("key")
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 1, 1, 1, 1, 1),
            ("2", 1, 1, 1, 3, 1, 1, 1, 1),
            ("2", 1, 2, 1, 3, 2, 1, 1, 1),
            ("2", 2, 2, 2, 3, 3, 3, 2, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_window_functions_without_partitionBy(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.orderBy("key", df.value)
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 4, 1, 1, 1, 1),
            ("2", 1, 1, 1, 4, 2, 2, 2, 1),
            ("2", 1, 2, 1, 4, 3, 2, 2, 2),
            ("2", 2, 2, 2, 4, 4, 4, 3, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

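    # Cumulative sum over a rows-based window; frame boundaries below
    # Window.unboundedPreceding or above Window.unboundedFollowing must not overflow.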
    def test_window_functions_cumulative_sum(self):
        df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
        from pyspark.sql import functions as F

        # Test cumulative sum
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow
        frame_end = Window.unboundedFollowing + 1
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end)))
        rs = sorted(sel.collect())
        expected = [("one", 3), ("two", 2)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_collect_functions(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql import functions

        self.assertEqual(
            sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
            [1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
            [1, 1, 1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
            ["1", "2"])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
            ["1", "2", "2", "2"])

    def test_datetime_functions(self):
        from pyspark.sql import functions
        from datetime import date
        df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
        parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
        self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)'])


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_functions import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)