[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF
## What changes were proposed in this pull request?

This PR proposes to support an alternative function form with group aggregate pandas UDF. The current form:

```
def foo(pdf):
    return ...
```

takes a single argument that is a pandas DataFrame. With this PR, an alternative form is also supported:

```
def foo(key, pdf):
    return ...
```

The alternative form takes two arguments: a tuple that represents the grouping key, and a pandas DataFrame that represents the data.

## How was this patch tested?

GroupbyApplyTests

Author: Li Jin <ice.xelloss@gmail.com>

Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key.
Parent: d6632d185e
Commit: 2cb23a8f51
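Before the diff itself, a minimal pandas-only sketch of the semantics being added (not part of this commit): `run_grouped_map` is a hypothetical helper that mimics what the Python worker does per group, dispatching on whether the UDF takes one argument (data only) or two (key, data).

```python
import pandas as pd

def run_grouped_map(pdf, group_cols, func, nargs):
    # Mimic Spark's per-group dispatch: one-arg UDFs receive only the data,
    # two-arg UDFs also receive the grouping key as a tuple.
    out = []
    for key, group in pdf.groupby(group_cols):
        key = key if isinstance(key, tuple) else (key,)
        out.append(func(group) if nargs == 1 else func(key, group))
    return pd.concat(out, ignore_index=True)

df = pd.DataFrame({"id": [1, 1, 2, 2, 2], "v": [1.0, 2.0, 3.0, 5.0, 10.0]})

def mean_udf(key, pdf):
    # key is a 1-tuple holding the current group's 'id' value
    return pd.DataFrame([key + (pdf.v.mean(),)], columns=["id", "v"])

print(run_grouped_map(df, ["id"], mean_udf, nargs=2))
```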
**python/pyspark/serializers.py**

```diff
@@ -250,6 +250,15 @@ class ArrowStreamPandasSerializer(Serializer):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
 
+    def arrow_to_pandas(self, arrow_column):
+        from pyspark.sql.types import from_arrow_type, \
+            _check_series_convert_date, _check_series_localize_timestamps
+
+        s = arrow_column.to_pandas()
+        s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
+        s = _check_series_localize_timestamps(s, self._timezone)
+        return s
+
     def dump_stream(self, iterator, stream):
         """
         Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -272,16 +281,11 @@ class ArrowStreamPandasSerializer(Serializer):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
-            _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
-        schema = from_arrow_schema(reader.schema)
         for batch in reader:
-            pdf = batch.to_pandas()
-            pdf = _check_dataframe_convert_date(pdf, schema)
-            pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
-            yield [c for _, c in pdf.iteritems()]
+            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
 
     def __repr__(self):
         return "ArrowStreamPandasSerializer"
```
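The `load_stream` refactor above converts each Arrow column to a pandas Series through the new `arrow_to_pandas` helper. A standalone sketch of the same per-column pattern (assumes pyarrow and pandas are installed; this is not the Spark code path itself, and `pa.open_stream` was the pyarrow 0.8-era stream reader):

```python
import pandas as pd
import pyarrow as pa

# Build a record batch like the ones Spark's Arrow stream carries
batch = pa.RecordBatch.from_pandas(pd.DataFrame({"x": [1, 2], "y": [3.0, 4.0]}))

# Convert column-by-column, mirroring the refactored load_stream
table = pa.Table.from_batches([batch])
series_list = [col.to_pandas() for col in table.itercolumns()]
print([s.dtype for s in series_list])  # [dtype('int64'), dtype('float64')]
```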
**python/pyspark/sql/functions.py**

```diff
@@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        |  2| 1.1094003924504583|
        +---+-------------------+
 
+       Alternatively, the user can define a function that takes two arguments.
+       In this case, the grouping key will be passed as the first argument and the data will
+       be passed as the second argument. The grouping key will be passed as a tuple of numpy
+       data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
+       as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
+       This is useful when the user does not want to hardcode grouping key in the function.
+
+       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+       >>> import pandas as pd  # doctest: +SKIP
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))  # doctest: +SKIP
+       >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
+       ... def mean_udf(key, pdf):
+       ...     # key is a tuple of one numpy.int64, which is the value
+       ...     # of 'id' for the current group
+       ...     return pd.DataFrame([key + (pdf.v.mean(),)])
+       >>> df.groupby('id').apply(mean_udf).show()  # doctest: +SKIP
+       +---+---+
+       | id|  v|
+       +---+---+
+       |  1|1.5|
+       |  2|6.0|
+       +---+---+
+
        .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
 
     3. GROUPED_AGG
```
**python/pyspark/sql/tests.py**

```diff
@@ -3903,7 +3903,7 @@ class PandasUDFTests(ReusedSQLTestCase):
                 return df
 
         with self.assertRaisesRegexp(ValueError, 'Invalid function'):
             @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
-            def foo(k, v):
+            def foo(k, v, w):
                 return k
@@ -4476,20 +4476,45 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
         df = self.data.withColumn("arr", array(col("id")))
 
-        foo_udf = pandas_udf(
-            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
-            StructType(
-                [StructField('id', LongType()),
-                 StructField('v', IntegerType()),
-                 StructField('arr', ArrayType(LongType())),
-                 StructField('v1', DoubleType()),
-                 StructField('v2', LongType())]),
+        # Different forms of group map pandas UDF, results of these are the same
+
+        output_schema = StructType(
+            [StructField('id', LongType()),
+             StructField('v', IntegerType()),
+             StructField('arr', ArrayType(LongType())),
+             StructField('v1', DoubleType()),
+             StructField('v2', LongType())])
+
+        udf1 = pandas_udf(
+            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            output_schema,
             PandasUDFType.GROUPED_MAP
         )
 
-        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
-        expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        udf2 = pandas_udf(
+            lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            output_schema,
+            PandasUDFType.GROUPED_MAP
+        )
+
+        udf3 = pandas_udf(
+            lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            output_schema,
+            PandasUDFType.GROUPED_MAP
+        )
+
+        result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
+        expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
+
+        result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
+        expected2 = expected1
+
+        result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
+        expected3 = expected1
+
+        self.assertPandasEqual(expected1, result1)
+        self.assertPandasEqual(expected2, result2)
+        self.assertPandasEqual(expected3, result3)
 
     def test_register_grouped_map_udf(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4648,6 +4673,80 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         result = df.groupby('time').apply(foo_udf).sort('time')
         self.assertPandasEqual(df.toPandas(), result.toPandas())
 
+    def test_udf_with_key(self):
+        from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+        df = self.data
+        pdf = df.toPandas()
+
+        def foo1(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+
+            return pdf.assign(v1=key[0],
+                              v2=pdf.v * key[0],
+                              v3=pdf.v * pdf.id,
+                              v4=pdf.v * pdf.id.mean())
+
+        def foo2(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+            assert type(key[1]) == np.int32
+
+            return pdf.assign(v1=key[0],
+                              v2=key[1],
+                              v3=pdf.v * key[0],
+                              v4=pdf.v + key[1])
+
+        def foo3(key, pdf):
+            assert type(key) == tuple
+            assert len(key) == 0
+            return pdf.assign(v1=pdf.v * pdf.id)
+
+        # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+        # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
+        udf1 = pandas_udf(
+            foo1,
+            'id long, v int, v1 long, v2 int, v3 long, v4 double',
+            PandasUDFType.GROUPED_MAP)
+
+        udf2 = pandas_udf(
+            foo2,
+            'id long, v int, v1 long, v2 int, v3 int, v4 int',
+            PandasUDFType.GROUPED_MAP)
+
+        udf3 = pandas_udf(
+            foo3,
+            'id long, v int, v1 long',
+            PandasUDFType.GROUPED_MAP)
+
+        # Test groupby column
+        result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+        expected1 = pdf.groupby('id')\
+            .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected1, result1)
+
+        # Test groupby expression
+        result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+        expected2 = pdf.groupby(pdf.id % 2)\
+            .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected2, result2)
+
+        # Test complex groupby
+        result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
+        expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
+            .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected3, result3)
+
+        # Test empty groupby
+        result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
+        expected4 = udf3.func((), pdf)
+        self.assertPandasEqual(expected4, result4)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
```
**python/pyspark/sql/types.py**

```diff
@@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
                        for field in arrow_schema])
 
 
+def _check_series_convert_date(series, data_type):
+    """
+    Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+
+    :param series: pandas.Series
+    :param data_type: a Spark data type for the series
+    """
+    if type(data_type) == DateType:
+        return series.dt.date
+    else:
+        return series
+
+
 def _check_dataframe_convert_date(pdf, schema):
     """ Correct date type value to use datetime.date.
 
@@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
     :param schema: a Spark schema of the pandas.DataFrame
     """
     for field in schema:
-        if type(field.dataType) == DateType:
-            pdf[field.name] = pdf[field.name].dt.date
+        pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
     return pdf
 
 
@@ -1725,6 +1737,29 @@ def _get_local_timezone():
     return os.environ.get('TZ', 'dateutil/:')
 
 
+def _check_series_localize_timestamps(s, timezone):
+    """
+    Convert timezone aware timestamps to timezone-naive in the specified timezone or local
+    timezone.
+
+    If the input series is not a timestamp series, then the same series is returned. If the input
+    series is a timestamp series, then a converted series is returned.
+
+    :param s: pandas.Series
+    :param timezone: the timezone to convert. if None then use local timezone
+    :return pandas.Series that have been converted to tz-naive
+    """
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64tz_dtype
+    tz = timezone or _get_local_timezone()
+    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+    if is_datetime64tz_dtype(s.dtype):
+        return s.dt.tz_convert(tz).dt.tz_localize(None)
+    else:
+        return s
+
+
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
 
@@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
     from pyspark.sql.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 
-    from pandas.api.types import is_datetime64tz_dtype
-    tz = timezone or _get_local_timezone()
     for column, series in pdf.iteritems():
-        # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-        if is_datetime64tz_dtype(series.dtype):
-            pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
+        pdf[column] = _check_series_localize_timestamps(series, timezone)
     return pdf
```
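To see what `_check_series_localize_timestamps` does to a tz-aware series, a small pandas-only illustration (not part of the diff):

```python
import pandas as pd

s = pd.Series(pd.date_range("2018-01-01", periods=2, freq="D", tz="UTC"))
print(s.dtype)  # datetime64[ns, UTC] -- tz-aware

# Convert to the target timezone, then drop the tz info, as the helper does
naive = s.dt.tz_convert("America/New_York").dt.tz_localize(None)
print(naive.dtype)  # datetime64[ns] -- tz-naive
```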
**python/pyspark/sql/udf.py**

```diff
@@ -17,6 +17,8 @@
 """
 User-defined function related classes and functions
 """
+import sys
+import inspect
 import functools
 
 from pyspark import SparkContext, since
@@ -24,6 +26,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
     _parse_datatype_string, to_arrow_type, to_arrow_schema
+from pyspark.util import _get_argspec
 
 __all__ = ["UDFRegistration"]
 
@@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType):
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
 
-        import inspect
-        import sys
         from pyspark.sql.utils import require_minimum_pyarrow_version
 
         require_minimum_pyarrow_version()
 
-        if sys.version_info[0] < 3:
-            # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
-            # See SPARK-23569.
-            argspec = inspect.getargspec(f)
-        else:
-            argspec = inspect.getfullargspec(f)
+        argspec = _get_argspec(f)
 
     if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
             argspec.varargs is None:
@@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType):
             "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
         )
 
-    if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
+    if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
+            and len(argspec.args) not in (1, 2):
         raise ValueError(
             "Invalid function: pandas_udfs with function type GROUPED_MAP "
-            "must take a single arg that is a pandas DataFrame."
-        )
+            "must take either one argument (data) or two arguments (key, data).")
 
     # Set the name of the UserDefinedFunction object to be the name of function f
     udf_obj = UserDefinedFunction(
```
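The GROUPED_MAP validation above keys off the UDF's argument count. A quick standalone check of the same introspection (Python 3 branch of `_get_argspec`, with illustrative function names):

```python
import inspect

def one_arg(pdf):
    return pdf

def two_args(key, pdf):
    return pdf

for f in (one_arg, two_args):
    argspec = inspect.getfullargspec(f)
    # Both forms pass the new check: len(argspec.args) in (1, 2)
    print(f.__name__, len(argspec.args), len(argspec.args) in (1, 2))
```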
**python/pyspark/util.py**

```diff
@@ -15,6 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import sys
+import inspect
 from py4j.protocol import Py4JJavaError
 
 __all__ = []
@@ -45,6 +48,19 @@ def _exception_message(excp):
     return str(excp)
 
 
+def _get_argspec(f):
+    """
+    Get argspec of a function. Supports both Python 2 and Python 3.
+    """
+    # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+    # See SPARK-23569.
+    if sys.version_info[0] < 3:
+        argspec = inspect.getargspec(f)
+    else:
+        argspec = inspect.getfullargspec(f)
+    return argspec
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
```
**python/pyspark/worker.py**

```diff
@@ -34,6 +34,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
+from pyspark.util import _get_argspec
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -91,10 +92,16 @@ def wrap_scalar_pandas_udf(f, return_type):
 
 
 def wrap_grouped_map_pandas_udf(f, return_type):
-    def wrapped(*series):
+    def wrapped(key_series, value_series):
         import pandas as pd
+        argspec = _get_argspec(f)
+
+        if len(argspec.args) == 1:
+            result = f(pd.concat(value_series, axis=1))
+        elif len(argspec.args) == 2:
+            key = tuple(s[0] for s in key_series)
+            result = f(key, pd.concat(value_series, axis=1))
 
-        result = f(pd.concat(series, axis=1))
         if not isinstance(result, pd.DataFrame):
             raise TypeError("Return type of the user-defined function should be "
                             "pandas.DataFrame, but is {}".format(type(result)))
@@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type):
     num_udfs = read_int(infile)
     udfs = {}
     call_udf = []
-    for i in range(num_udfs):
-        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
-        udfs['f%d' % i] = udf
-        args = ["a[%d]" % o for o in arg_offsets]
-        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-    # Create function like this:
-    # lambda a: (f0(a0), f1(a1, a2), f2(a3))
-    # In the special case of a single UDF this will return a single result rather
-    # than a tuple of results; this is the format that the JVM side expects.
-    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-    mapper = eval(mapper_str, udfs)
+    mapper_str = ""
+    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+        # Create function like this:
+        # lambda a: f([a[0]], [a[0], a[1]])
+
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
+        udfs['f'] = udf
+        split_offset = arg_offsets[0] + 1
+        arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
+        arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
+        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
+    else:
+        # Create function like this:
+        # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
+        # In the special case of a single UDF this will return a single result rather
+        # than a tuple of results; this is the format that the JVM side expects.
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+
+    mapper = eval(mapper_str, udfs)
     func = lambda _, it: map(mapper, it)
 
     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
```
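To make the offset protocol concrete, here is the split logic from `read_udfs` run on a hypothetical `arg_offsets` for one grouping column `id` that also appears among two data columns (`id`, `v`); the values are illustrative, not taken from a real plan:

```python
# arg_offsets[0] = number of grouping attributes; the rest are offsets into
# the deduplicated row: first the grouping attrs, then the data attrs.
arg_offsets = [1, 0, 0, 1]

split_offset = arg_offsets[0] + 1                          # = 2
arg0 = ["a[%d]" % o for o in arg_offsets[1:split_offset]]  # grouping attrs
arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]   # data attrs
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
print(mapper_str)  # lambda a: f([a[0]], [a[0], a[1]])
```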
**sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala**

```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec(
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
     val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-    val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
-    val schema = StructType(child.schema.drop(groupingAttributes.length))
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
     val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
 
+    // Deduplicate the grouping attributes.
+    // If a grouping attribute also appears in data attributes, then we don't need to send the
+    // grouping attribute to Python worker. If a grouping attribute is not in data attributes,
+    // then we need to send this grouping attribute to python worker.
+    //
+    // We use argOffsets to distinguish grouping attributes and data attributes as following:
+    //
+    // argOffsets[0] is the length of grouping attributes
+    // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
+    // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
+
+    val dataAttributes = child.output.drop(groupingAttributes.length)
+    val groupingIndicesInData = groupingAttributes.map { attribute =>
+      dataAttributes.indexWhere(attribute.semanticEquals)
+    }
+
+    val groupingArgOffsets = new ArrayBuffer[Int]
+    val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
+    val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
+
+    // Non duplicate grouping attributes are added to nonDupGroupingAttributes and
+    // their offsets are 0, 1, 2 ...
+    // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
+    // their offsets are n + index, where n is the total number of non duplicate grouping
+    // attributes and index is the index in the data attributes that the grouping attribute
+    // is a duplicate of.
+
+    groupingAttributes.zip(groupingIndicesInData).foreach {
+      case (attribute, index) =>
+        if (index == -1) {
+          groupingArgOffsets += nonDupGroupingAttributes.length
+          nonDupGroupingAttributes += attribute
+        } else {
+          groupingArgOffsets += index + nonDupGroupingSize
+        }
+    }
+
+    val dataArgOffsets = nonDupGroupingAttributes.length until
+      (nonDupGroupingAttributes.length + dataAttributes.length)
+
+    val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
+
+    // Attributes after deduplication
+    val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
+    val dedupSchema = StructType.fromAttributes(dedupAttributes)
+
     inputRDD.mapPartitionsInternal { iter =>
       val grouped = if (groupingAttributes.isEmpty) {
         Iterator(iter)
       } else {
         val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
-        val dropGrouping =
-          UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output)
+        val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
         groupedIter.map {
-          case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
+          case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
         }
       }
 
@@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec(
 
       val columnarBatchIter = new ArrowPythonRunner(
         chainedFunc, bufferSize, reuseWorker,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema,
         sessionLocalTimeZone, pandasRespectSessionTimeZone)
         .compute(grouped, context.partitionId(), context)
```
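For illustration, the deduplication scheme described in the comments above can be sketched in Python over plain attribute names (Spark operates on `Attribute` objects with `semanticEquals`, but the index arithmetic is the same):

```python
def compute_arg_offsets(grouping, data):
    # Index of each grouping attribute inside the data attributes (-1 if absent)
    indices = [data.index(g) if g in data else -1 for g in grouping]
    n_non_dup = sum(1 for i in indices if i == -1)

    grouping_offsets, non_dup = [], []
    for g, idx in zip(grouping, indices):
        if idx == -1:
            grouping_offsets.append(len(non_dup))       # fresh slot: 0, 1, 2, ...
            non_dup.append(g)
        else:
            grouping_offsets.append(n_non_dup + idx)    # points into the data section

    data_offsets = list(range(len(non_dup), len(non_dup) + len(data)))
    return [len(grouping)] + grouping_offsets + data_offsets

# 'id' is both a grouping and a data attribute, so it is not sent twice:
print(compute_arg_offsets(["id"], ["id", "v"]))  # [1, 0, 0, 1]
```

Note that `[1, 0, 0, 1]` is exactly the `arg_offsets` value consumed by the worker-side split shown earlier.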