[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF

## What changes were proposed in this pull request?

This PR proposes to support an alternative function form with group aggregate pandas UDF.

The current form:
```
def foo(pdf):
    return ...
```
This form takes a single argument, a pandas DataFrame.

With this PR, an alternative form is supported:
```
def foo(key, pdf):
    return ...
```
The alternative form takes two arguments: a tuple that represents the grouping key, and a pandas DataFrame that represents the data.
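For example, a minimal sketch of the two-argument form in use (this assumes an active `spark` session; the `group_mean` name and the `id`/`v` columns are illustrative, mirroring the doctest added further down):
```
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def group_mean(key, pdf):
    # key is a tuple of numpy scalars, e.g. (numpy.int64(1),) for the group id == 1
    return pd.DataFrame([key + (pdf.v.mean(),)])

df.groupby("id").apply(group_mean).show()
```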

## How was this patch tested?

GroupbyApplyTests

Author: Li Jin <ice.xelloss@gmail.com>

Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key.
Li Jin authored on 2018-03-08 20:29:07 +09:00; committed by hyukjinkwon
parent d6632d185e
commit 2cb23a8f51
8 changed files with 295 additions and 56 deletions


@@ -250,6 +250,15 @@ class ArrowStreamPandasSerializer(Serializer):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import from_arrow_type, \
_check_series_convert_date, _check_series_localize_timestamps
s = arrow_column.to_pandas()
s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
s = _check_series_localize_timestamps(s, self._timezone)
return s
def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -272,16 +281,11 @@ class ArrowStreamPandasSerializer(Serializer):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
import pyarrow as pa
reader = pa.open_stream(stream)
schema = from_arrow_schema(reader.schema)
for batch in reader:
pdf = batch.to_pandas()
pdf = _check_dataframe_convert_date(pdf, schema)
pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
yield [c for _, c in pdf.iteritems()]
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
def __repr__(self):
return "ArrowStreamPandasSerializer"


@@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 1.1094003924504583|
+---+-------------------+
Alternatively, the user can define a function that takes two arguments.
In this case, the grouping key will be passed as the first argument and the data will
be passed as the second argument. The grouping key will be passed as a tuple of numpy
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
This is useful when the user does not want to hardcode grouping key in the function.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> import pandas as pd # doctest: +SKIP
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def mean_udf(key, pdf):
... # key is a tuple of one numpy.int64, which is the value
... # of 'id' for the current group
... return pd.DataFrame([key + (pdf.v.mean(),)])
>>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
+---+---+
| id| v|
+---+---+
| 1|1.5|
| 2|6.0|
+---+---+
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
3. GROUPED_AGG


@@ -3903,7 +3903,7 @@ class PandasUDFTests(ReusedSQLTestCase):
return df
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
def foo(k, v):
def foo(k, v, w):
return k
@@ -4476,20 +4476,45 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
df = self.data.withColumn("arr", array(col("id")))
foo_udf = pandas_udf(
# Different forms of group map pandas UDF, results of these are the same
output_schema = StructType(
[StructField('id', LongType()),
StructField('v', IntegerType()),
StructField('arr', ArrayType(LongType())),
StructField('v1', DoubleType()),
StructField('v2', LongType())])
udf1 = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
StructType(
[StructField('id', LongType()),
StructField('v', IntegerType()),
StructField('arr', ArrayType(LongType())),
StructField('v1', DoubleType()),
StructField('v2', LongType())]),
output_schema,
PandasUDFType.GROUPED_MAP
)
result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
self.assertPandasEqual(expected, result)
udf2 = pandas_udf(
lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
output_schema,
PandasUDFType.GROUPED_MAP
)
udf3 = pandas_udf(
lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
output_schema,
PandasUDFType.GROUPED_MAP
)
result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
expected2 = expected1
result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
expected3 = expected1
self.assertPandasEqual(expected1, result1)
self.assertPandasEqual(expected2, result2)
self.assertPandasEqual(expected3, result3)
def test_register_grouped_map_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4648,6 +4673,80 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
result = df.groupby('time').apply(foo_udf).sort('time')
self.assertPandasEqual(df.toPandas(), result.toPandas())
def test_udf_with_key(self):
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
df = self.data
pdf = df.toPandas()
def foo1(key, pdf):
import numpy as np
assert type(key) == tuple
assert type(key[0]) == np.int64
return pdf.assign(v1=key[0],
v2=pdf.v * key[0],
v3=pdf.v * pdf.id,
v4=pdf.v * pdf.id.mean())
def foo2(key, pdf):
import numpy as np
assert type(key) == tuple
assert type(key[0]) == np.int64
assert type(key[1]) == np.int32
return pdf.assign(v1=key[0],
v2=key[1],
v3=pdf.v * key[0],
v4=pdf.v + key[1])
def foo3(key, pdf):
assert type(key) == tuple
assert len(key) == 0
return pdf.assign(v1=pdf.v * pdf.id)
# v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
# v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
udf1 = pandas_udf(
foo1,
'id long, v int, v1 long, v2 int, v3 long, v4 double',
PandasUDFType.GROUPED_MAP)
udf2 = pandas_udf(
foo2,
'id long, v int, v1 long, v2 int, v3 int, v4 int',
PandasUDFType.GROUPED_MAP)
udf3 = pandas_udf(
foo3,
'id long, v int, v1 long',
PandasUDFType.GROUPED_MAP)
# Test groupby column
result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
expected1 = pdf.groupby('id')\
.apply(lambda x: udf1.func((x.id.iloc[0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
self.assertPandasEqual(expected1, result1)
# Test groupby expression
result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
expected2 = pdf.groupby(pdf.id % 2)\
.apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
self.assertPandasEqual(expected2, result2)
# Test complex groupby
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
.apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
self.assertPandasEqual(expected3, result3)
# Test empty groupby
result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
expected4 = udf3.func((), pdf)
self.assertPandasEqual(expected4, result4)
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,


@@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])
def _check_series_convert_date(series, data_type):
"""
Cast the series to datetime.date if it's a date type, otherwise returns the original series.
:param series: pandas.Series
:param data_type: a Spark data type for the series
"""
if type(data_type) == DateType:
return series.dt.date
else:
return series
def _check_dataframe_convert_date(pdf, schema):
""" Correct date type value to use datetime.date.
@@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
:param schema: a Spark schema of the pandas.DataFrame
"""
for field in schema:
if type(field.dataType) == DateType:
pdf[field.name] = pdf[field.name].dt.date
pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
return pdf
@@ -1725,6 +1737,29 @@ def _get_local_timezone():
return os.environ.get('TZ', 'dateutil/:')
def _check_series_localize_timestamps(s, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
If the input series is not a timestamp series, then the same series is returned. If the input
series is a timestamp series, then a converted series is returned.
:param s: pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series that have been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64tz_dtype
tz = timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(tz).dt.tz_localize(None)
else:
return s
def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64tz_dtype
tz = timezone or _get_local_timezone()
for column, series in pdf.iteritems():
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf
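As a rough illustration of what the new `_check_series_localize_timestamps` helper does for a tz-aware series, here is a standalone pandas sketch of its core conversion (the `America/New_York` timezone is an assumed session timezone, not something taken from the patch):
```
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype

# A timezone-aware series, as produced when the Arrow data carries a timezone.
s = pd.Series(pd.to_datetime(["2018-03-08 12:00:00"])).dt.tz_localize("UTC")
assert is_datetime64tz_dtype(s.dtype)

# Same conversion as the helper: shift to the target timezone, then drop the tz info.
converted = s.dt.tz_convert("America/New_York").dt.tz_localize(None)
# converted[0] is Timestamp('2018-03-08 07:00:00'), i.e. tz-naive local wall time.
```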


@@ -17,6 +17,8 @@
"""
User-defined function related classes and functions
"""
import sys
import inspect
import functools
from pyspark import SparkContext, since
@@ -24,6 +26,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
_parse_datatype_string, to_arrow_type, to_arrow_schema
from pyspark.util import _get_argspec
__all__ = ["UDFRegistration"]
@@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType):
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
import inspect
import sys
from pyspark.sql.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
if sys.version_info[0] < 3:
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
# See SPARK-23569.
argspec = inspect.getargspec(f)
else:
argspec = inspect.getfullargspec(f)
argspec = _get_argspec(f)
if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
argspec.varargs is None:
@@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType):
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
and len(argspec.args) not in (1, 2):
raise ValueError(
"Invalid function: pandas_udfs with function type GROUPED_MAP "
"must take a single arg that is a pandas DataFrame."
)
"must take either one argument (data) or two arguments (key, data).")
# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(


@@ -15,6 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import inspect
from py4j.protocol import Py4JJavaError
__all__ = []
@@ -45,6 +48,19 @@ def _exception_message(excp):
return str(excp)
def _get_argspec(f):
"""
Get argspec of a function. Supports both Python 2 and Python 3.
"""
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
# See SPARK-23569.
if sys.version_info[0] < 3:
argspec = inspect.getargspec(f)
else:
argspec = inspect.getfullargspec(f)
return argspec
if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()


@@ -34,6 +34,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.util import _get_argspec
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -91,10 +92,16 @@ def wrap_scalar_pandas_udf(f, return_type):
def wrap_grouped_map_pandas_udf(f, return_type):
def wrapped(*series):
def wrapped(key_series, value_series):
import pandas as pd
argspec = _get_argspec(f)
if len(argspec.args) == 1:
result = f(pd.concat(value_series, axis=1))
elif len(argspec.args) == 2:
key = tuple(s[0] for s in key_series)
result = f(key, pd.concat(value_series, axis=1))
result = f(pd.concat(series, axis=1))
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be "
"pandas.DataFrame, but is {}".format(type(result)))
@@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type):
num_udfs = read_int(infile)
udfs = {}
call_udf = []
for i in range(num_udfs):
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
# Create function like this:
# lambda a: (f0(a0), f1(a1, a2), f2(a3))
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
mapper = eval(mapper_str, udfs)
mapper_str = ""
if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
# Create function like this:
# lambda a: f([a[0]], [a[0], a[1]])
# We assume there is only one UDF here because grouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
udfs['f'] = udf
split_offset = arg_offsets[0] + 1
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
else:
# Create function like this:
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
for i in range(num_udfs):
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
mapper = eval(mapper_str, udfs)
func = lambda _, it: map(mapper, it)
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,


@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec(
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
val schema = StructType(child.schema.drop(groupingAttributes.length))
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
// Deduplicate the grouping attributes.
// If a grouping attribute also appears in data attributes, then we don't need to send the
// grouping attribute to Python worker. If a grouping attribute is not in data attributes,
// then we need to send this grouping attribute to python worker.
//
// We use argOffsets to distinguish grouping attributes and data attributes as following:
//
// argOffsets[0] is the length of grouping attributes
// argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
// argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
val dataAttributes = child.output.drop(groupingAttributes.length)
val groupingIndicesInData = groupingAttributes.map { attribute =>
dataAttributes.indexWhere(attribute.semanticEquals)
}
val groupingArgOffsets = new ArrayBuffer[Int]
val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
// Non duplicate grouping attributes are added to nonDupGroupingAttributes and
// their offsets are 0, 1, 2 ...
// Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
// their offsets are n + index, where n is the total number of non duplicate grouping
// attributes and index is the index in the data attributes that the grouping attribute
// is a duplicate of.
groupingAttributes.zip(groupingIndicesInData).foreach {
case (attribute, index) =>
if (index == -1) {
groupingArgOffsets += nonDupGroupingAttributes.length
nonDupGroupingAttributes += attribute
} else {
groupingArgOffsets += index + nonDupGroupingSize
}
}
val dataArgOffsets = nonDupGroupingAttributes.length until
(nonDupGroupingAttributes.length + dataAttributes.length)
val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
// Attributes after deduplication
val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
val dedupSchema = StructType.fromAttributes(dedupAttributes)
inputRDD.mapPartitionsInternal { iter =>
val grouped = if (groupingAttributes.isEmpty) {
Iterator(iter)
} else {
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
val dropGrouping =
UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output)
val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
groupedIter.map {
case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
}
}
@@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec(
val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(grouped, context.partitionId(), context)
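To make the argOffsets layout described in the comments above concrete, the following small Python sketch (not part of the patch) splits the offsets the same way the new code in `read_udfs` does; the single grouping column `id` over child output `(id, v)` is only an assumed example:
```
def split_arg_offsets(arg_offsets):
    # arg_offsets[0] is the number of grouping attributes, followed by
    # the grouping-key offsets and then the data offsets.
    num_grouping = arg_offsets[0]
    split_offset = num_grouping + 1
    grouping_offsets = arg_offsets[1:split_offset]
    data_offsets = arg_offsets[split_offset:]
    return grouping_offsets, data_offsets

# Grouping by `id` over child output (id, v): `id` duplicates data column 0,
# so the JVM side sends argOffsets = [1, 0, 0, 1] with dedupSchema = (id, v).
assert split_arg_offsets([1, 0, 0, 1]) == ([0], [0, 1])
# The generated mapper is then equivalent to: lambda a: f([a[0]], [a[0], a[1]])
```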