2018-11-14 01:51:11 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
|
|
|
import functools
|
|
|
|
import pydoc
|
|
|
|
import shutil
|
|
|
|
import tempfile
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
from pyspark import SparkContext
|
|
|
|
from pyspark.sql import SparkSession, Column, Row
|
2018-12-11 01:16:51 -05:00
|
|
|
from pyspark.sql.functions import UserDefinedFunction, udf
|
2018-11-14 01:51:11 -05:00
|
|
|
from pyspark.sql.types import *
|
|
|
|
from pyspark.sql.utils import AnalysisException
|
|
|
|
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
|
2018-11-14 23:30:52 -05:00
|
|
|
from pyspark.testing.utils import QuietTest
|
2018-11-14 01:51:11 -05:00
|
|
|
|
|
|
|
|
|
|
|
class UDFTests(ReusedSQLTestCase):
|
|
|
|
|
|
|
|
def test_udf_with_callable(self):
    # An instance of a class defining __call__ (not only a plain function) can back a UDF.
    rows = [Row(number=i, squared=i**2) for i in range(10)]
    frame = self.spark.createDataFrame(self.sc.parallelize(rows))

    class PlusFour:
        def __call__(self, col):
            # Implicitly returns None for null input.
            if col is not None:
                return col + 4

    plus_four = UserDefinedFunction(PlusFour(), LongType())
    projected = frame.select(plus_four(frame['number']).alias('plus_four'))
    # sum(0..9) + 10 * 4 == 85
    total = projected.agg({'plus_four': 'sum'}).collect()[0][0]
    self.assertEqual(total, 85)
|
|
|
|
|
|
|
|
def test_udf_with_partial_function(self):
    # A functools.partial object (which has no __name__ of its own) can back a UDF.
    rows = [Row(number=i, squared=i**2) for i in range(10)]
    frame = self.spark.createDataFrame(self.sc.parallelize(rows))

    def add_param(col, param):
        if col is not None:
            return col + param

    plus_four = UserDefinedFunction(functools.partial(add_param, param=4), LongType())
    projected = frame.select(plus_four(frame['number']).alias('plus_four'))
    # sum(0..9) + 10 * 4 == 85
    self.assertEqual(projected.agg({'plus_four': 'sum'}).collect()[0][0], 85)
|
|
|
|
|
|
|
|
def test_udf(self):
    # Register a two-argument UDF through the catalog and call it from SQL.
    self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
    [first] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
    self.assertEqual(first[0], 5)

    # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
    sql_ctx = self.spark._wrapped
    sql_ctx.registerFunction("oneArg", lambda x: len(x), IntegerType())
    [second] = sql_ctx.sql("SELECT oneArg('test')").collect()
    self.assertEqual(second[0], 4)
|
|
|
|
|
|
|
|
def test_udf2(self):
    # A registered UDF can be used in both the projection and the WHERE clause.
    with self.tempView("test"):
        self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
        source = self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))
        source.createOrReplaceTempView("test")
        query = "SELECT strlen(a) FROM test WHERE strlen(a) > 1"
        [result] = self.spark.sql(query).collect()
        self.assertEqual(4, result[0])
|
|
|
|
|
|
|
|
def test_udf3(self):
    # Register an existing UserDefinedFunction created without an explicit return type.
    # The SQL result is expected to come back as the string u'5'.
    wrapped = UserDefinedFunction(lambda x, y: len(x) + y)
    two_args = self.spark.catalog.registerFunction("twoArgs", wrapped)
    self.assertEqual(two_args.deterministic, True)
    [first] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
    self.assertEqual(first[0], u'5')
|
|
|
|
|
|
|
|
def test_udf_registration_return_type_none(self):
    # Passing None as the return type when registering a UserDefinedFunction keeps the
    # type declared on the UDF itself ("integer" here).
    integer_udf = UserDefinedFunction(lambda x, y: len(x) + y, "integer")
    two_args = self.spark.catalog.registerFunction("twoArgs", integer_udf, None)
    self.assertEqual(two_args.deterministic, True)
    [first] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
    self.assertEqual(first[0], 5)
|
|
|
|
|
|
|
|
def test_udf_registration_return_type_not_none(self):
    # Supplying a return type while registering an existing UserDefinedFunction
    # must raise TypeError("Invalid return type ...").
    with QuietTest(self.sc), self.assertRaisesRegexp(TypeError, "Invalid return type"):
        string_udf = UserDefinedFunction(lambda x, y: len(x) + y, StringType())
        self.spark.catalog.registerFunction("f", string_udf, StringType())
|
|
|
|
|
|
|
|
def test_nondeterministic_udf(self):
    # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations:
    # the downstream column must be computed from the same random value.
    import random
    rand_udf = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
    self.assertEqual(rand_udf.deterministic, False)
    frame = self.spark.createDataFrame([Row(1)]).select(rand_udf().alias('RAND'))
    plus_ten = udf(lambda rand: rand + 10, IntegerType())
    [row] = frame.withColumn('RAND_PLUS_TEN', plus_ten('RAND')).collect()
    rand_value, rand_plus_ten = row
    self.assertEqual(rand_value + 10, rand_plus_ten)
|
|
|
|
|
|
|
|
def test_nondeterministic_udf2(self):
    # The nondeterministic flag survives catalog registration, and both the original
    # and the registered UDF are callable from SQL and from the DataFrame API.
    import random
    random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
    self.assertEqual(random_udf.deterministic, False)
    registered = self.spark.catalog.registerFunction("randInt", random_udf)
    self.assertEqual(registered.deterministic, False)

    [from_sql] = self.spark.sql("SELECT randInt()").collect()
    self.assertEqual(from_sql[0], 6)
    for fn in (registered, random_udf):
        [from_api] = self.spark.range(1).select(fn()).collect()
        self.assertEqual(from_api[0], 6)

    # render_doc() reproduces the help() exception without printing output
    pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
    pydoc.render_doc(random_udf)
    pydoc.render_doc(registered)
    pydoc.render_doc(udf(lambda x: x).asNondeterministic)
|
|
|
|
|
|
|
|
def test_nondeterministic_udf3(self):
    # regression test for SPARK-23233
    identity = udf(lambda x: x)
    # Using the UDF once caches its JVM UDF instance.
    self.spark.range(1).select(identity("id"))
    # asNondeterministic() should reset that cache so the new flag reaches the plan.
    identity = identity.asNondeterministic()
    plan = self.spark.range(1).select(identity("id"))._jdf.logicalPlan()
    self.assertFalse(plan.projectList().head().deterministic())
|
|
|
|
|
|
|
|
def test_nondeterministic_udf_in_aggregate(self):
    # A nondeterministic UDF inside an aggregate function is rejected at analysis time,
    # with and without grouping.
    from pyspark.sql.functions import sum as spark_sum
    import random
    rand_udf = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
    frame = self.spark.range(10)

    with QuietTest(self.sc):
        with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
            frame.groupby('id').agg(spark_sum(rand_udf())).collect()
        with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
            frame.agg(spark_sum(rand_udf())).collect()
|
|
|
|
|
|
|
|
def test_chained_udf(self):
    # A registered UDF can be nested in itself and mixed with built-in arithmetic.
    self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
    cases = [
        ("SELECT double(1)", 2),
        ("SELECT double(double(1))", 4),
        ("SELECT double(double(1) + 1)", 6),
    ]
    for query, expected in cases:
        [row] = self.spark.sql(query).collect()
        self.assertEqual(row[0], expected)
|
|
|
|
|
|
|
|
def test_single_udf_with_repeated_argument(self):
    # regression test for SPARK-20685: the same literal passed twice to one UDF.
    self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
    first_row = self.spark.sql("SELECT add(1, 1)").first()
    self.assertEqual(tuple(first_row), (2,))
|
|
|
|
|
|
|
|
def test_multiple_udfs(self):
    # Several (and nested) UDF calls in the same SELECT list.
    register = self.spark.catalog.registerFunction
    select_one = lambda query: tuple(self.spark.sql(query).collect()[0])

    register("double", lambda x: x * 2, IntegerType())
    [row] = self.spark.sql("SELECT double(1), double(2)").collect()
    self.assertEqual(tuple(row), (2, 4))
    [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
    self.assertEqual(tuple(row), (4, 12))

    register("add", lambda x, y: x + y, IntegerType())
    [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
    self.assertEqual(tuple(row), (6, 5))
|
|
|
|
|
|
|
|
def test_udf_in_filter_on_top_of_outer_join(self):
    # Filter on a UDF-derived column added after a left outer join.
    left = self.spark.createDataFrame([Row(a=1)])
    right = self.spark.createDataFrame([Row(a=1)])
    joined = left.join(right, on='a', how='left_outer')
    constant_x = udf(lambda x: 'x')
    joined = joined.withColumn('b', constant_x(joined.a))
    self.assertEqual(joined.filter('b = "x"').collect(), [Row(a=1, b='x')])
|
|
|
|
|
|
|
|
def test_udf_in_filter_on_top_of_join(self):
    # regression test for SPARK-18589: a two-sided UDF filter over an explicit cross join.
    left = self.spark.createDataFrame([Row(a=1)])
    right = self.spark.createDataFrame([Row(b=1)])
    same = udf(lambda a, b: a == b, BooleanType())
    matched = left.crossJoin(right).filter(same("a", "b")).collect()
    self.assertEqual(matched, [Row(a=1, b=1)])
|
|
|
|
|
|
|
|
def test_udf_in_join_condition(self):
    # regression test for SPARK-25314
    left = self.spark.createDataFrame([Row(a=1)])
    right = self.spark.createDataFrame([Row(b=1)])
    same = udf(lambda a, b: a == b, BooleanType())
    # The UDF reads attributes from both sides of the join, so it is pulled out as a
    # Filter over a cross join; that only runs when implicit cross joins are enabled.
    joined = left.join(right, same("a", "b"))
    with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
            joined.collect()
    with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
        self.assertEqual(joined.collect(), [Row(a=1, b=1)])
|
|
|
|
|
2018-11-28 07:38:42 -05:00
|
|
|
def test_udf_in_left_outer_join_condition(self):
    # regression test for SPARK-26147
    from pyspark.sql.functions import col
    left = self.spark.createDataFrame([Row(a=1)])
    right = self.spark.createDataFrame([Row(b=1)])
    to_str = udf(lambda a: str(a), StringType())
    # The join condition can't be pushed down, as it refers to attributes from both sides.
    # The Python UDF only refers to attributes from one side, so it is evaluable.
    condition = to_str("a") == col("b").cast("string")
    joined = left.join(right, condition, how="left_outer")
    with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
        self.assertEqual(joined.collect(), [Row(a=1, b=1)])
|
|
|
|
|
2018-11-14 01:51:11 -05:00
|
|
|
def test_udf_and_common_filter_in_join_condition(self):
    # regression test for SPARK-25314
    # A join condition combining a two-sided UDF with an ordinary equality predicate.
    left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
    right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
    same = udf(lambda a, b: a == b, BooleanType())
    joined = left.join(right, [same("a", "b"), left.a1 == right.b1])
    # spark.sql.crossJoin.enabled=true is not needed because the UDF is not the only
    # join condition.
    self.assertEqual(joined.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
|
|
|
|
|
|
|
|
def test_udf_not_supported_in_join_condition(self):
    # regression test for SPARK-25314
    # A two-sided Python UDF in a join condition is rejected for every join type except
    # inner join.
    left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
    right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
    same = udf(lambda a, b: a == b, BooleanType())

    def assert_unsupported(join_type, type_string):
        with self.assertRaisesRegexp(
                AnalysisException,
                'Using PythonUDF.*%s is not supported.' % type_string):
            left.join(right, [same("a", "b"), left.a1 == right.b1], join_type).collect()

    unsupported = [
        ("full", "FullOuter"),
        ("left", "LeftOuter"),
        ("right", "RightOuter"),
        ("leftanti", "LeftAnti"),
        ("leftsemi", "LeftSemi"),
    ]
    for join_type, type_string in unsupported:
        assert_unsupported(join_type, type_string)
|
2018-11-14 01:51:11 -05:00
|
|
|
|
[SPARK-28323][SQL][PYTHON] PythonUDF should be usable in a join condition
## What changes were proposed in this pull request?
There is a bug in `ExtractPythonUDFs` that produces wrong result attributes. It causes a failure when using `PythonUDF`s across multiple child plans, e.g., a join. An example is using `PythonUDF`s in a join condition.
```python
>>> left = spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
>>> right = spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
>>> f = udf(lambda a: a, IntegerType())
>>> df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
>>> df.collect()
19/07/10 12:20:49 ERROR Executor: Exception in task 5.0 in stage 0.0 (TID 5)
java.lang.ArrayIndexOutOfBoundsException: 1
at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.genericGet(rows.scala:201)
at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getAs(rows.scala:35)
at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt(rows.scala:36)
at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt$(rows.scala:36)
at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.isNullAt(rows.scala:195)
at org.apache.spark.sql.catalyst.expressions.JoinedRow.isNullAt(JoinedRow.scala:70)
...
```
## How was this patch tested?
Added test.
Closes #25091 from viirya/SPARK-28323.
Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
2019-07-10 19:29:58 -04:00
|
|
|
def test_udf_as_join_condition(self):
    # One-sided UDF calls compared in an equi-join condition alongside a plain predicate
    # (see SPARK-28323).
    left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
    right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
    identity = udf(lambda a: a, IntegerType())

    conditions = [identity("a") == identity("b"), left.a1 == right.b1]
    joined = left.join(right, conditions)
    self.assertEqual(joined.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
|
|
|
|
|
2018-11-14 01:51:11 -05:00
|
|
|
def test_udf_without_arguments(self):
    # A zero-argument UDF registered without an explicit return type.
    self.spark.catalog.registerFunction("foo", lambda: "bar")
    rows = self.spark.sql("SELECT foo()").collect()
    [only] = rows
    self.assertEqual(only[0], "bar")
|
|
|
|
|
|
|
|
def test_udf_with_array_type(self):
    # UDFs consuming an array column and a map column.
    with self.tempView("test"):
        source = [Row(l=list(range(3)), d={"key": list(range(5))})]
        self.spark.createDataFrame(self.sc.parallelize(source)).createOrReplaceTempView("test")
        register = self.spark.catalog.registerFunction
        register("copylist", lambda values: list(values), ArrayType(IntegerType()))
        register("maplen", lambda mapping: len(mapping), IntegerType())
        query = "select copylist(l), maplen(d) from test"
        [(copied, map_len)] = self.spark.sql(query).collect()
        self.assertEqual(list(range(3)), copied)
        self.assertEqual(1, map_len)
|
|
|
|
|
|
|
|
def test_broadcast_in_udf(self):
    # A UDF closure can read the value of a broadcast variable.
    lookup = {"a": "aa", "b": "bb", "c": "abc"}
    shared = self.sc.broadcast(lookup)
    self.spark.catalog.registerFunction("MYUDF", lambda x: shared.value[x] if x else '')
    cases = (
        ("SELECT MYUDF('c')", "abc"),
        ("SELECT MYUDF('')", ""),
    )
    for query, expected in cases:
        [result] = self.spark.sql(query).collect()
        self.assertEqual(expected, result[0])
|
|
|
|
|
|
|
|
def test_udf_with_filter_function(self):
    # A UDF predicate AND-ed with an ordinary column comparison inside filter().
    df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
    from pyspark.sql.functions import col
    from pyspark.sql.types import BooleanType

    below_two = udf(lambda a: a < 2, BooleanType())
    predicate = below_two(col("key")) & (df.value < "2")
    selected = df.select(col("key"), col("value")).filter(predicate)
    self.assertEqual(selected.collect(), [Row(key=1, value='1')])
|
|
|
|
|
|
|
|
def test_udf_with_aggregate_function(self):
    # UDFs used (1) as a filter after distinct() and (2) as the grouping key, as the
    # input of an aggregate, and in the projection over the aggregated result.
    df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
    # Alias Spark's sum so the Python builtin is not shadowed in this scope.
    from pyspark.sql.functions import col, sum as spark_sum
    from pyspark.sql.types import BooleanType

    my_filter = udf(lambda a: a == 1, BooleanType())
    sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
    self.assertEqual(sel.collect(), [Row(key=1)])

    my_copy = udf(lambda x: x, IntegerType())
    my_add = udf(lambda a, b: int(a + b), IntegerType())
    my_strlen = udf(lambda x: len(x), IntegerType())
    sel = df.groupBy(my_copy(col("key")).alias("k"))\
        .agg(spark_sum(my_strlen(col("value"))).alias("s"))\
        .select(my_add(col("k"), col("s")).alias("t"))
    # key 1 -> 1 + (1 + 1 + 1) = 4; key 2 -> 2 + 1 = 3.
    # groupBy() gives no row-order guarantee, so compare the rows order-independently.
    self.assertEqual(sorted(sel.collect()), [Row(t=3), Row(t=4)])
|
|
|
|
|
|
|
|
def test_udf_in_generate(self):
    # UDFs returning arrays, fed into the explode() generator.
    from pyspark.sql.functions import explode

    def check_pairs(rows, expected_pairs):
        # Compare the first two fields of each leading row with the expected pairs.
        for i, (first, second) in enumerate(expected_pairs):
            self.assertEqual(rows[i][0], first)
            self.assertEqual(rows[i][1], second)

    df = self.spark.range(5)
    to_range = udf(lambda x: list(range(x)), ArrayType(LongType()))
    summed = df.select(explode(to_range(*df))).groupBy().sum().first()
    self.assertEqual(summed[0], 10)

    df = self.spark.range(3)
    exploded = df.select("id", explode(to_range(df.id))).collect()
    check_pairs(exploded, [(1, 0), (2, 0), (2, 1)])

    around = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
    exploded = df.select("id", explode(around(df.id))).collect()
    check_pairs(exploded, [(0, -1), (0, 0), (1, 0), (1, 1)])
|
|
|
|
|
|
|
|
def test_udf_with_order_by_and_limit(self):
    # A UDF projected over an ordered plan followed by limit(1) must return the first row.
    identity = udf(lambda value: value, IntegerType())
    ordered = self.spark.range(10).orderBy("id")
    limited = ordered.select(ordered.id, identity(ordered.id).alias("copy")).limit(1)
    self.assertEqual(limited.collect(), [Row(id=0, copy=0)])
|
|
|
|
|
|
|
|
def test_udf_registration_returns_udf(self):
    # udf.register() returns a callable UDF that behaves exactly like the SQL-registered name.
    ids = self.spark.range(10)

    plus_three_udf = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
    via_sql = ids.selectExpr("add_three(id) AS plus_three").collect()
    via_api = ids.select(plus_three_udf("id").alias("plus_three")).collect()
    self.assertListEqual(via_sql, via_api)

    # This is to check if a 'SQLContext.udf' can call its alias.
    sql_ctx = self.spark._wrapped
    plus_four_udf = sql_ctx.udf.register("add_four", lambda x: x + 4, IntegerType())
    via_sql = ids.selectExpr("add_four(id) AS plus_four").collect()
    via_api = ids.select(plus_four_udf("id").alias("plus_four")).collect()
    self.assertListEqual(via_sql, via_api)
|
|
|
|
|
|
|
|
def test_non_existed_udf(self):
    # Registering a Java UDF whose class cannot be loaded must raise AnalysisException.
    session = self.spark
    with self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf"):
        session.udf.registerJavaFunction("udf1", "non_existed_udf")

    # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
    legacy_ctx = session._wrapped
    with self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf"):
        legacy_ctx.registerJavaFunction("udf1", "non_existed_udf")
|
|
|
|
|
|
|
|
def test_non_existed_udaf(self):
    # Registering a Java UDAF whose class cannot be loaded must raise AnalysisException.
    session = self.spark
    with self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf"):
        session.udf.registerJavaUDAF("udaf1", "non_existed_udaf")
|
|
|
|
|
|
|
|
def test_udf_with_input_file_name(self):
    # input_file_name() passed through a Python UDF must still report the source file.
    from pyspark.sql.functions import input_file_name

    source_file = udf(lambda path: path, StringType())
    file_path = "python/test_support/sql/people1.json"
    row = self.spark.read.json(file_path).select(source_file(input_file_name())).first()
    # assertIn shows the actual path on failure; assertTrue(s.find(...) != -1) only says False.
    self.assertIn("people1.json", row[0])
|
|
|
|
|
|
|
|
def test_udf_with_input_file_name_for_hadooprdd(self):
    # input_file_name() must survive a Python UDF when the JSON is read from an RDD,
    # whether that RDD came from textFile or newAPIHadoopFile.
    from pyspark.sql.functions import input_file_name

    def filename(path):
        return path

    same_text = udf(filename, StringType())

    def check_file_name(rdd):
        # Read JSON from ``rdd``, pipe input_file_name() through the UDF, check the path.
        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
        row = df.select(same_text(df['file'])).first()
        self.assertIn("people.json", row[0])

    check_file_name(self.sc.textFile('python/test_support/sql/people.json'))

    check_file_name(self.sc.newAPIHadoopFile(
        'python/test_support/sql/people.json',
        'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
        'org.apache.hadoop.io.LongWritable',
        'org.apache.hadoop.io.Text'))
|
|
|
|
|
|
|
|
def test_udf_defers_judf_initialization(self):
    # Kept apart from UDFInitializationTests: calling the UDF initializes its Java
    # counterpart, and those tests must not have a context initialized.
    identity = UserDefinedFunction(lambda x: x, StringType())

    self.assertIsNone(identity._judf_placeholder,
                      "judf should not be initialized before the first call.")

    column = identity("foo")
    self.assertIsInstance(column, Column, "UDF call should return a Column.")

    self.assertIsNotNone(identity._judf_placeholder,
                         "judf should be initialized after UDF has been called.")
|
|
|
|
|
|
|
|
def test_udf_with_string_return_type(self):
    # Return types may be given as DDL type strings instead of DataType instances.
    add_one = UserDefinedFunction(lambda x: x + 1, "integer")
    make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
    make_array = UserDefinedFunction(
        lambda x: [float(v) for v in range(x, x + 3)], "array<double>")

    frame = self.spark.range(1, 2).toDF("x")
    actual = frame.select(add_one("x"), make_pair("x"), make_array("x")).first()

    self.assertTupleEqual((2, Row(x=-1, y=1), [1.0, 2.0, 3.0]), actual)
|
|
|
|
|
|
|
|
def test_udf_shouldnt_accept_noncallable_object(self):
    # Wrapping something that is not callable (here None) must be rejected with TypeError.
    with self.assertRaises(TypeError):
        UserDefinedFunction(None, StringType())
|
|
|
|
|
|
|
|
def test_udf_with_decorator(self):
    # Every decorator spelling of ``udf`` must work and must produce the expected type:
    # a DataType, a keyword returnType, bare @udf, @udf(), and DDL type strings.
    from pyspark.sql.functions import lit
    from pyspark.sql.types import IntegerType, DoubleType

    @udf(IntegerType())
    def add_one(x):
        return x + 1 if x is not None else None

    @udf(returnType=DoubleType())
    def add_two(x):
        return float(x + 2) if x is not None else None

    @udf
    def to_upper(x):
        return None if x is None else x.upper()

    @udf()
    def to_lower(x):
        return None if x is None else x.lower()

    @udf
    def substr(x, start, end):
        return None if x is None else x[start:end]

    @udf("long")
    def trunc(x):
        return int(x)

    @udf(returnType="double")
    def as_double(x):
        return float(x)

    source = self.spark.createDataFrame(
        [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
    df = source.select(
        add_one("one"),
        add_two("one"),
        to_upper("Foo"),
        to_lower("Foo"),
        substr("foobar", lit(0), lit(3)),
        trunc("float"),
        as_double("one"))

    expected_types = ["int", "double", "string", "string", "string", "bigint", "double"]
    expected_values = [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
    self.assertListEqual([dtype for _, dtype in df.dtypes], expected_types)
    self.assertListEqual(list(df.first()), expected_values)
|
|
|
|
|
|
|
|
def test_udf_wrapper(self):
    # udf() must return a wrapper that exposes the wrapped callable (``func``) and the
    # return type (``returnType``), and whose docstring contains the callable's docstring.
    # This must hold for plain functions, callable instances and functools.partial objects.
    from pyspark.sql.types import IntegerType

    def check_wrapper(func):
        # Wrap ``func`` and verify the three properties above.
        return_type = IntegerType()
        wrapped = udf(func, return_type)
        # assertIn reports both strings on failure, unlike assertTrue(a in b).
        self.assertIn(func.__doc__, wrapped.__doc__)
        self.assertEqual(func, wrapped.func)
        self.assertEqual(return_type, wrapped.returnType)

    def f(x):
        """Identity"""
        return x

    check_wrapper(f)

    class F(object):
        """Identity"""
        def __call__(self, x):
            return x

    callable_instance = F()
    check_wrapper(callable_instance)

    check_wrapper(functools.partial(callable_instance, x=1))
|
|
|
|
|
|
|
|
def test_nonparam_udf_with_aggregate(self):
    # A zero-argument UDF added after distinct() must yield its constant on each row.
    import pyspark.sql.functions as sql_functions

    duplicated = self.spark.createDataFrame([(1, 2), (1, 2)])
    const_udf = sql_functions.udf(lambda: "const_str")
    rows = duplicated.distinct().withColumn("a", const_udf()).collect()
    self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
|
|
|
|
|
|
|
|
# SPARK-24721
@unittest.skipIf(not test_compiled, test_not_compiled_message)
def test_datasource_with_udf(self):
    # Python UDFs in projections and filters must work on top of a file source,
    # a data source (SimpleScanSource) and a DataSource V2 (SimpleDataSourceV2).
    from pyspark.sql.functions import lit, col

    # Reserve a unique path that does not exist yet, for the CSV writer to create.
    path = tempfile.mkdtemp()
    shutil.rmtree(path)

    try:
        self.spark.range(1).write.mode("overwrite").format('csv').save(path)
        filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
        datasource_df = self.spark.read \
            .format("org.apache.spark.sql.sources.SimpleScanSource") \
            .option('from', 0).option('to', 1).load().toDF('i')
        datasource_v2_df = self.spark.read \
            .format("org.apache.spark.sql.connector.SimpleDataSourceV2") \
            .load().toDF('i', 'j')
        all_dfs = [filesource_df, datasource_df, datasource_v2_df]

        c1 = udf(lambda x: x + 1, 'int')(lit(1))
        c2 = udf(lambda x: x + 1, 'int')(col('i'))

        f1 = udf(lambda x: False, 'boolean')(lit(1))
        f2 = udf(lambda x: False, 'boolean')(col('i'))

        # UDF over a literal, used in a projection.
        for df in all_dfs:
            result = df.withColumn('c', c1)
            expected = df.withColumn('c', lit(2))
            self.assertEqual(expected.collect(), result.collect())

        # UDF over a source column, used in a projection.
        for df in all_dfs:
            result = df.withColumn('c', c2)
            expected = df.withColumn('c', col('i') + 1)
            self.assertEqual(expected.collect(), result.collect())

        # Always-false UDF predicates must filter out every row.
        for df in all_dfs:
            for f in [f1, f2]:
                result = df.filter(f)
                self.assertEqual(0, result.count())
    finally:
        # If the write failed, ``path`` may not exist. A plain rmtree would then raise
        # and mask the original failure, so ignore errors here.
        shutil.rmtree(path, ignore_errors=True)
|
|
|
|
|
|
|
|
# SPARK-25591
def test_same_accumulator_in_udfs(self):
    # Two different UDFs updating one shared accumulator must both contribute: 1 + 100.
    schema = StructType([StructField("a", IntegerType(), True),
                         StructField("b", IntegerType(), True)])
    frame = self.spark.createDataFrame([[1, 2]], schema=schema)

    counter = self.sc.accumulator(0)

    def bump_by_one(value):
        counter.add(1)
        return value

    def bump_by_hundred(value):
        counter.add(100)
        return value

    one_udf = udf(bump_by_one, IntegerType())
    hundred_udf = udf(bump_by_hundred, IntegerType())
    frame = frame.withColumn("out1", one_udf(frame["a"]))
    frame = frame.withColumn("out2", hundred_udf(frame["b"]))
    frame.collect()

    self.assertEqual(counter.value, 101)
|
|
|
|
|
2018-12-11 01:16:51 -05:00
|
|
|
# SPARK-26293
def test_udf_in_subquery(self):
    # A view whose plan contains a Python UDF filter must be usable inside an IN subquery.
    identity = udf(lambda x: x, "long")
    with self.tempView("v"):
        self.spark.range(1).filter(identity("id") >= 0).createTempView("v")
        query = "select i from values(0L) as data(i) where i in (select id from v)"
        self.assertEqual(self.spark.sql(query).collect(), [Row(i=0)])
|
|
|
|
|
2019-03-12 11:23:26 -04:00
|
|
|
def test_udf_globals_not_overwritten(self):
    # Inside the worker, the UDF's global ``map`` must not resolve to an itertools object.
    @udf('string')
    def check_map_global():
        assert "itertools" not in str(map)

    self.spark.range(1).select(check_map_global()).collect()
|
|
|
|
|
2019-07-30 20:10:24 -04:00
|
|
|
def test_worker_original_stdin_closed(self):
    # SPARK-26175: a worker must not inherit the daemon's standard input. It should be
    # replaced with '/dev/null', so reading it immediately hits EOF.
    def read_stdin_then_pass_through(partition):
        import sys
        leftover = sys.stdin.read()
        assert leftover == '', "Expect read EOF from stdin."
        return partition

    self.sc.parallelize(range(1), 1).mapPartitions(read_stdin_then_pass_through).count()
|
|
|
|
|
2019-11-08 22:19:14 -05:00
|
|
|
def test_udf_with_256_args(self):
    # A UDF must accept 256 column arguments.
    num_cols = 256
    row = ["data-%d" % i for i in range(num_cols)]
    df = self.spark.createDataFrame([row] * 5)

    def always_success(*args):
        return "success"

    wide_udf = udf(always_success, StringType())
    first = df.select(wide_udf(*df.columns)).first()
    self.assertEqual(first[0], "success")
|
|
|
|
|
2018-11-14 01:51:11 -05:00
|
|
|
|
|
|
|
class UDFInitializationTests(unittest.TestCase):
    # Checks that constructing a UserDefinedFunction does not start Spark.

    def tearDown(self):
        # Stop whatever session/context may exist so later tests start clean.
        session = SparkSession._instantiatedSession
        if session is not None:
            session.stop()

        context = SparkContext._active_spark_context
        if context is not None:
            context.stop()

    def test_udf_init_shouldnt_initialize_context(self):
        UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            SparkContext._active_spark_context,
            "SparkContext shouldn't be initialized when UserDefinedFunction is created.")
        self.assertIsNone(
            SparkSession._instantiatedSession,
            "SparkSession shouldn't be initialized when UserDefinedFunction is created.")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    from pyspark.sql.tests.test_udf import *

    # xmlrunner is optional: use it for XML reports when installed, else the default runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|