[SPARK-30681][PYSPARK][SQL] Add higher order functions API to PySpark

### What changes were proposed in this pull request?

This PR add Python API for invoking following higher functions:

- `transform`
- `exists`
- `forall`
- `filter`
- `aggregate`
- `zip_with`
- `transform_keys`
- `transform_values`
- `map_filter`
- `map_zip_with`

to `pyspark.sql`. Each of these accepts plain Python functions of one of the following types

- `(Column) -> Column: ...`
- `(Column, Column) -> Column: ...`
- `(Column, Column, Column) -> Column: ...`

Internally this proposal piggbacks on objects supporting Scala implementation ([SPARK-27297](https://issues.apache.org/jira/browse/SPARK-27297)) by:

1. Creating  required `UnresolvedNamedLambdaVariables`  exposing these as PySpark `Columns`
2. Invoking Python function with these columns as arguments.
3. Using the result, and underlying JVM objects from 1., to create `expressions.LambdaFunction` which is passed to desired expression, and repacked as Python `Column`.

### Why are the changes needed?

Currently higher order functions are available only using SQL and Scala API and can use only SQL expressions

```python
df.selectExpr("transform(values, x -> x + 1)")
```

This works reasonably well for simple functions, but can get really ugly with complex functions (complex functions, casts), resulting objects are somewhat verbose and we don't get any IDE support.  Additionally DSL used, though  very simple, is not documented.

With changes propose here, above query could be rewritten as:

```python
df.select(transform("values", lambda x: x + 1))
```

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

- For positive cases this PR adds doctest strings covering possible usage patterns.
- For negative cases (unsupported function types) this PR adds unit tests.

### Notes

If approved, the same approach can be used in SparkR.

Closes #27406 from zero323/SPARK-30681.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
zero323 2020-02-28 12:59:39 +09:00 committed by HyukjinKwon
parent c467961e8a
commit 7de33f56e8
2 changed files with 480 additions and 0 deletions

View file

@ -2843,6 +2843,463 @@ def from_csv(col, schema, options={}):
return Column(jc) return Column(jc)
def _unresolved_named_lambda_variable(*name_parts):
"""
Create `o.a.s.sql.expressions.UnresolvedNamedLambdaVariable`,
convert it to o.s.sql.Column and wrap in Python `Column`
:param name_parts: str
"""
sc = SparkContext._active_spark_context
name_parts_seq = _to_seq(sc, name_parts)
expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
return Column(
sc._jvm.Column(
expressions.UnresolvedNamedLambdaVariable(name_parts_seq)
)
)
def _get_lambda_parameters(f):
import inspect
signature = inspect.signature(f)
parameters = signature.parameters.values()
# We should exclude functions that use
# variable args and keyword argnames
# as well as keyword only args
supported_parmeter_types = {
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
}
# Validate that
# function arity is between 1 and 3
if not (1 <= len(parameters) <= 3):
raise ValueError(
"f should take between 1 and 3 arguments, but provided function takes {}".format(
len(parameters)
)
)
# and all arguments can be used as positional
if not all(p.kind in supported_parmeter_types for p in parameters):
raise ValueError(
"f should use only POSITIONAL or POSITIONAL OR KEYWORD arguments"
)
return parameters
def _get_lambda_parameters_legacy(f):
# TODO (SPARK-29909) Remove once 2.7 support is dropped
import inspect
spec = inspect.getargspec(f)
if not 1 <= len(spec.args) <= 3 or spec.varargs or spec.keywords:
raise ValueError(
"f should take between 1 and 3 arguments, but provided function takes {}".format(
spec
)
)
return spec.args
def _create_lambda(f):
"""
Create `o.a.s.sql.expressions.LambdaFunction` corresponding
to transformation described by f
:param f: A Python of one of the following forms:
- (Column) -> Column: ...
- (Column, Column) -> Column: ...
- (Column, Column, Column) -> Column: ...
"""
if sys.version_info >= (3, 3):
parameters = _get_lambda_parameters(f)
else:
parameters = _get_lambda_parameters_legacy(f)
sc = SparkContext._active_spark_context
expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
argnames = ["x", "y", "z"]
args = [
_unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)]
]
result = f(*args)
if not isinstance(result, Column):
raise ValueError("f should return Column, got {}".format(type(result)))
jexpr = result._jc.expr()
jargs = _to_seq(sc, [arg._jc.expr() for arg in args])
return expressions.LambdaFunction(jexpr, jargs, False)
def _invoke_higher_order_function(name, cols, funs):
"""
Invokes expression identified by name,
(relative to ```org.apache.spark.sql.catalyst.expressions``)
and wraps the result with Column (first Scala one, then Python).
:param name: Name of the expression
:param cols: a list of columns
:param funs: a list of((*Column) -> Column functions.
:return: a Column
"""
sc = SparkContext._active_spark_context
expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
expr = getattr(expressions, name)
jcols = [_to_java_column(col).expr() for col in cols]
jfuns = [_create_lambda(f) for f in funs]
return Column(sc._jvm.Column(expr(*jcols + jfuns)))
@since(3.1)
def transform(col, f):
"""
Returns an array of elements after applying a transformation to each element in the input array.
:param col: name of column or expression
:param f: a function that is applied to each element of the input array.
Can take one of the following forms:
- Unary ``(x: Column) -> Column: ...``
- Binary ``(x: Column, i: Column) -> Column...``, where the second argument is
a 0-based index of the element.
and can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values"))
>>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show()
+------------+
| doubled|
+------------+
|[2, 4, 6, 8]|
+------------+
>>> def alternate(x, i):
... return when(i % 2 == 0, x).otherwise(-x)
>>> df.select(transform("values", alternate).alias("alternated")).show()
+--------------+
| alternated|
+--------------+
|[1, -2, 3, -4]|
+--------------+
"""
return _invoke_higher_order_function("ArrayTransform", [col], [f])
@since(3.1)
def exists(col, f):
"""
Returns whether a predicate holds for one or more elements in the array.
:param col: name of column or expression
:param f: an function ``(x: Column) -> Column: ...`` returning the Boolean expression.
Can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])],("key", "values"))
>>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()
+------------+
|any_negative|
+------------+
| false|
| true|
+------------+
"""
return _invoke_higher_order_function("ArrayExists", [col], [f])
@since(3.1)
def forall(col, f):
"""
Returns whether a predicate holds for every element in the array.
:param col: name of column or expression
:param f: an function ``(x: Column) -> Column: ...`` returning the Boolean expression.
Can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame(
... [(1, ["bar"]), (2, ["foo", "bar"]), (3, ["foobar", "foo"])],
... ("key", "values")
... )
>>> df.select(forall("values", lambda x: x.rlike("foo")).alias("all_foo")).show()
+-------+
|all_foo|
+-------+
| false|
| false|
| true|
+-------+
"""
return _invoke_higher_order_function("ArrayForAll", [col], [f])
@since(3.1)
def filter(col, f):
"""
Returns an array of elements for which a predicate holds in a given array.
:param col: name of column or expression
:param f: A function that returns the Boolean expression.
Can take one of the following forms:
- Unary ``(x: Column) -> Column: ...``
- Binary ``(x: Column, i: Column) -> Column...``, where the second argument is
a 0-based index of the element.
and can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame(
... [(1, ["2018-09-20", "2019-02-03", "2019-07-01", "2020-06-01"])],
... ("key", "values")
... )
>>> def after_second_quarter(x):
... return month(to_date(x)) > 6
>>> df.select(
... filter("values", after_second_quarter).alias("after_second_quarter")
... ).show(truncate=False)
+------------------------+
|after_second_quarter |
+------------------------+
|[2018-09-20, 2019-07-01]|
+------------------------+
"""
return _invoke_higher_order_function("ArrayFilter", [col], [f])
@since(3.1)
def aggregate(col, zero, merge, finish=None):
"""
Applies a binary operator to an initial state and all elements in the array,
and reduces this to a single state. The final state is converted into the final result
by applying a finish function.
Both functions can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:param col: name of column or expression
:param zero: initial value. Name of column or expression
:param merge: a binary function ``(acc: Column, x: Column) -> Column...`` returning expression
of the same type as ``zero``
:param finish: an optional unary function ``(x: Column) -> Column: ...``
used to convert accumulated value.
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))
>>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show()
+----+
| sum|
+----+
|42.0|
+----+
>>> def merge(acc, x):
... count = acc.count + 1
... sum = acc.sum + x
... return struct(count.alias("count"), sum.alias("sum"))
>>> df.select(
... aggregate(
... "values",
... struct(lit(0).alias("count"), lit(0.0).alias("sum")),
... merge,
... lambda acc: acc.sum / acc.count,
... ).alias("mean")
... ).show()
+----+
|mean|
+----+
| 8.4|
+----+
"""
if finish is not None:
return _invoke_higher_order_function(
"ArrayAggregate",
[col, zero],
[merge, finish]
)
else:
return _invoke_higher_order_function(
"ArrayAggregate",
[col, zero],
[merge]
)
@since(3.1)
def zip_with(col1, col2, f):
"""
Merge two given arrays, element-wise, into a single array using a function.
If one array is shorter, nulls are appended at the end to match the length of the longer
array, before applying the function.
:param col1: name of the first column or expression
:param col2: name of the second column or expression
:param f: a binary function ``(x1: Column, x2: Column) -> Column...``
Can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([(1, [1, 3, 5, 8], [0, 2, 4, 6])], ("id", "xs", "ys"))
>>> df.select(zip_with("xs", "ys", lambda x, y: x ** y).alias("powers")).show(truncate=False)
+---------------------------+
|powers |
+---------------------------+
|[1.0, 9.0, 625.0, 262144.0]|
+---------------------------+
>>> df = spark.createDataFrame([(1, ["foo", "bar"], [1, 2, 3])], ("id", "xs", "ys"))
>>> df.select(zip_with("xs", "ys", lambda x, y: concat_ws("_", x, y)).alias("xs_ys")).show()
+-----------------+
| xs_ys|
+-----------------+
|[foo_1, bar_2, 3]|
+-----------------+
"""
return _invoke_higher_order_function("ZipWith", [col1, col2], [f])
@since(3.1)
def transform_keys(col, f):
"""
Applies a function to every key-value pair in a map and returns
a map with the results of those applications as the new keys for the pairs.
:param col: name of column or expression
:param f: a binary function ``(k: Column, v: Column) -> Column...``
Can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([(1, {"foo": -2.0, "bar": 2.0})], ("id", "data"))
>>> df.select(transform_keys(
... "data", lambda k, _: upper(k)).alias("data_upper")
... ).show(truncate=False)
+-------------------------+
|data_upper |
+-------------------------+
|[BAR -> 2.0, FOO -> -2.0]|
+-------------------------+
"""
return _invoke_higher_order_function("TransformKeys", [col], [f])
@since(3.1)
def transform_values(col, f):
"""
Applies a function to every key-value pair in a map and returns
a map with the results of those applications as the new values for the pairs.
:param col: name of column or expression
:param f: a binary function ``(k: Column, v: Column) -> Column...``
Can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([(1, {"IT": 10.0, "SALES": 2.0, "OPS": 24.0})], ("id", "data"))
>>> df.select(transform_values(
... "data", lambda k, v: when(k.isin("IT", "OPS"), v + 10.0).otherwise(v)
... ).alias("new_data")).show(truncate=False)
+---------------------------------------+
|new_data |
+---------------------------------------+
|[OPS -> 34.0, IT -> 20.0, SALES -> 2.0]|
+---------------------------------------+
"""
return _invoke_higher_order_function("TransformValues", [col], [f])
@since(3.1)
def map_filter(col, f):
"""
Returns a map whose key-value pairs satisfy a predicate.
:param col: name of column or expression
:param f: a binary function ``(k: Column, v: Column) -> Column...``
Can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([(1, {"foo": 42.0, "bar": 1.0, "baz": 32.0})], ("id", "data"))
>>> df.select(map_filter(
... "data", lambda _, v: v > 30.0).alias("data_filtered")
... ).show(truncate=False)
+--------------------------+
|data_filtered |
+--------------------------+
|[baz -> 32.0, foo -> 42.0]|
+--------------------------+
"""
return _invoke_higher_order_function("MapFilter", [col], [f])
@since(3.1)
def map_zip_with(col1, col2, f):
"""
Merge two given maps, key-wise into a single map using a function.
:param col1: name of the first column or expression
:param col2: name of the second column or expression
:param f: a ternary function ``(k: Column, v1: Column, v2: Column) -> Column...``
Can use methods of :class:`pyspark.sql.Column`, functions defined in
:py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
Python ``UserDefinedFunctions`` are not supported
(`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
:return: a :class:`pyspark.sql.Column`
>>> df = spark.createDataFrame([
... (1, {"IT": 24.0, "SALES": 12.00}, {"IT": 2.0, "SALES": 1.4})],
... ("id", "base", "ratio")
... )
>>> df.select(map_zip_with(
... "base", "ratio", lambda k, v1, v2: round(v1 * v2, 2)).alias("updated_data")
... ).show(truncate=False)
+---------------------------+
|updated_data |
+---------------------------+
|[SALES -> 16.8, IT -> 48.0]|
+---------------------------+
"""
return _invoke_higher_order_function("MapZipWith", [col1, col2], [f])
# ---------------------------- User Defined Function ---------------------------------- # ---------------------------- User Defined Function ----------------------------------
@since(1.3) @since(1.3)

View file

@ -337,6 +337,29 @@ class FunctionsTests(ReusedSQLTestCase):
self.assertListEqual(actual, expected) self.assertListEqual(actual, expected)
def test_higher_order_function_failures(self):
from pyspark.sql.functions import col, exists, transform
# Should fail with varargs
with self.assertRaises(ValueError):
transform(col("foo"), lambda *x: lit(1))
# Should fail with kwargs
with self.assertRaises(ValueError):
transform(col("foo"), lambda **x: lit(1))
# Should fail with nullary function
with self.assertRaises(ValueError):
transform(col("foo"), lambda: lit(1))
# Should fail with quaternary function
with self.assertRaises(ValueError):
transform(col("foo"), lambda x1, x2, x3, x4: lit(1))
# Should fail if function doesn't return Column
with self.assertRaises(ValueError):
transform(col("foo"), lambda x: 1)
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest