[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF
## What changes were proposed in this pull request? This PR proposes to support an alternative function form with group aggregate pandas UDF. The current form: ``` def foo(pdf): return ... ``` Takes a single arg that is a pandas DataFrame. With this PR, an alternative form is supported: ``` def foo(key, pdf): return ... ``` The alternative form takes two arguments - a tuple that represents the grouping key, and a pandas DataFrame that represents the data. ## How was this patch tested? GroupbyApplyTests Author: Li Jin <ice.xelloss@gmail.com> Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key.
This commit is contained in:
parent
d6632d185e
commit
2cb23a8f51
|
@ -250,6 +250,15 @@ class ArrowStreamPandasSerializer(Serializer):
|
|||
super(ArrowStreamPandasSerializer, self).__init__()
|
||||
self._timezone = timezone
|
||||
|
||||
def arrow_to_pandas(self, arrow_column):
|
||||
from pyspark.sql.types import from_arrow_type, \
|
||||
_check_series_convert_date, _check_series_localize_timestamps
|
||||
|
||||
s = arrow_column.to_pandas()
|
||||
s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
|
||||
s = _check_series_localize_timestamps(s, self._timezone)
|
||||
return s
|
||||
|
||||
def dump_stream(self, iterator, stream):
|
||||
"""
|
||||
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
|
||||
|
@ -272,16 +281,11 @@ class ArrowStreamPandasSerializer(Serializer):
|
|||
"""
|
||||
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
|
||||
"""
|
||||
from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
|
||||
_check_dataframe_localize_timestamps
|
||||
import pyarrow as pa
|
||||
reader = pa.open_stream(stream)
|
||||
schema = from_arrow_schema(reader.schema)
|
||||
|
||||
for batch in reader:
|
||||
pdf = batch.to_pandas()
|
||||
pdf = _check_dataframe_convert_date(pdf, schema)
|
||||
pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
|
||||
yield [c for _, c in pdf.iteritems()]
|
||||
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
|
||||
|
||||
def __repr__(self):
|
||||
return "ArrowStreamPandasSerializer"
|
||||
|
|
|
@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
|
|||
| 2| 1.1094003924504583|
|
||||
+---+-------------------+
|
||||
|
||||
Alternatively, the user can define a function that takes two arguments.
|
||||
In this case, the grouping key will be passed as the first argument and the data will
|
||||
be passed as the second argument. The grouping key will be passed as a tuple of numpy
|
||||
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
|
||||
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
|
||||
This is useful when the user does not want to hardcode grouping key in the function.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> import pandas as pd # doctest: +SKIP
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v")) # doctest: +SKIP
|
||||
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def mean_udf(key, pdf):
|
||||
... # key is a tuple of one numpy.int64, which is the value
|
||||
... # of 'id' for the current group
|
||||
... return pd.DataFrame([key + (pdf.v.mean(),)])
|
||||
>>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
|
||||
+---+---+
|
||||
| id| v|
|
||||
+---+---+
|
||||
| 1|1.5|
|
||||
| 2|6.0|
|
||||
+---+---+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
|
||||
|
||||
3. GROUPED_AGG
|
||||
|
|
|
@ -3903,7 +3903,7 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
return df
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
|
||||
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
|
||||
def foo(k, v):
|
||||
def foo(k, v, w):
|
||||
return k
|
||||
|
||||
|
||||
|
@ -4476,20 +4476,45 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
|
|||
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
|
||||
df = self.data.withColumn("arr", array(col("id")))
|
||||
|
||||
foo_udf = pandas_udf(
|
||||
# Different forms of group map pandas UDF, results of these are the same
|
||||
|
||||
output_schema = StructType(
|
||||
[StructField('id', LongType()),
|
||||
StructField('v', IntegerType()),
|
||||
StructField('arr', ArrayType(LongType())),
|
||||
StructField('v1', DoubleType()),
|
||||
StructField('v2', LongType())])
|
||||
|
||||
udf1 = pandas_udf(
|
||||
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
|
||||
StructType(
|
||||
[StructField('id', LongType()),
|
||||
StructField('v', IntegerType()),
|
||||
StructField('arr', ArrayType(LongType())),
|
||||
StructField('v1', DoubleType()),
|
||||
StructField('v2', LongType())]),
|
||||
output_schema,
|
||||
PandasUDFType.GROUPED_MAP
|
||||
)
|
||||
|
||||
result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
|
||||
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
|
||||
self.assertPandasEqual(expected, result)
|
||||
udf2 = pandas_udf(
|
||||
lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
|
||||
output_schema,
|
||||
PandasUDFType.GROUPED_MAP
|
||||
)
|
||||
|
||||
udf3 = pandas_udf(
|
||||
lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
|
||||
output_schema,
|
||||
PandasUDFType.GROUPED_MAP
|
||||
)
|
||||
|
||||
result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
|
||||
expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
|
||||
|
||||
result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
|
||||
expected2 = expected1
|
||||
|
||||
result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
|
||||
expected3 = expected1
|
||||
|
||||
self.assertPandasEqual(expected1, result1)
|
||||
self.assertPandasEqual(expected2, result2)
|
||||
self.assertPandasEqual(expected3, result3)
|
||||
|
||||
def test_register_grouped_map_udf(self):
|
||||
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
|
@ -4648,6 +4673,80 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
|
|||
result = df.groupby('time').apply(foo_udf).sort('time')
|
||||
self.assertPandasEqual(df.toPandas(), result.toPandas())
|
||||
|
||||
def test_udf_with_key(self):
|
||||
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
|
||||
df = self.data
|
||||
pdf = df.toPandas()
|
||||
|
||||
def foo1(key, pdf):
|
||||
import numpy as np
|
||||
assert type(key) == tuple
|
||||
assert type(key[0]) == np.int64
|
||||
|
||||
return pdf.assign(v1=key[0],
|
||||
v2=pdf.v * key[0],
|
||||
v3=pdf.v * pdf.id,
|
||||
v4=pdf.v * pdf.id.mean())
|
||||
|
||||
def foo2(key, pdf):
|
||||
import numpy as np
|
||||
assert type(key) == tuple
|
||||
assert type(key[0]) == np.int64
|
||||
assert type(key[1]) == np.int32
|
||||
|
||||
return pdf.assign(v1=key[0],
|
||||
v2=key[1],
|
||||
v3=pdf.v * key[0],
|
||||
v4=pdf.v + key[1])
|
||||
|
||||
def foo3(key, pdf):
|
||||
assert type(key) == tuple
|
||||
assert len(key) == 0
|
||||
return pdf.assign(v1=pdf.v * pdf.id)
|
||||
|
||||
# v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
|
||||
# v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
|
||||
udf1 = pandas_udf(
|
||||
foo1,
|
||||
'id long, v int, v1 long, v2 int, v3 long, v4 double',
|
||||
PandasUDFType.GROUPED_MAP)
|
||||
|
||||
udf2 = pandas_udf(
|
||||
foo2,
|
||||
'id long, v int, v1 long, v2 int, v3 int, v4 int',
|
||||
PandasUDFType.GROUPED_MAP)
|
||||
|
||||
udf3 = pandas_udf(
|
||||
foo3,
|
||||
'id long, v int, v1 long',
|
||||
PandasUDFType.GROUPED_MAP)
|
||||
|
||||
# Test groupby column
|
||||
result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
|
||||
expected1 = pdf.groupby('id')\
|
||||
.apply(lambda x: udf1.func((x.id.iloc[0],), x))\
|
||||
.sort_values(['id', 'v']).reset_index(drop=True)
|
||||
self.assertPandasEqual(expected1, result1)
|
||||
|
||||
# Test groupby expression
|
||||
result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
|
||||
expected2 = pdf.groupby(pdf.id % 2)\
|
||||
.apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
|
||||
.sort_values(['id', 'v']).reset_index(drop=True)
|
||||
self.assertPandasEqual(expected2, result2)
|
||||
|
||||
# Test complex groupby
|
||||
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
|
||||
expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
|
||||
.apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
|
||||
.sort_values(['id', 'v']).reset_index(drop=True)
|
||||
self.assertPandasEqual(expected3, result3)
|
||||
|
||||
# Test empty groupby
|
||||
result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
|
||||
expected4 = udf3.func((), pdf)
|
||||
self.assertPandasEqual(expected4, result4)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not _have_pandas or not _have_pyarrow,
|
||||
|
|
|
@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
|
|||
for field in arrow_schema])
|
||||
|
||||
|
||||
def _check_series_convert_date(series, data_type):
|
||||
"""
|
||||
Cast the series to datetime.date if it's a date type, otherwise returns the original series.
|
||||
|
||||
:param series: pandas.Series
|
||||
:param data_type: a Spark data type for the series
|
||||
"""
|
||||
if type(data_type) == DateType:
|
||||
return series.dt.date
|
||||
else:
|
||||
return series
|
||||
|
||||
|
||||
def _check_dataframe_convert_date(pdf, schema):
|
||||
""" Correct date type value to use datetime.date.
|
||||
|
||||
|
@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
|
|||
:param schema: a Spark schema of the pandas.DataFrame
|
||||
"""
|
||||
for field in schema:
|
||||
if type(field.dataType) == DateType:
|
||||
pdf[field.name] = pdf[field.name].dt.date
|
||||
pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
|
||||
return pdf
|
||||
|
||||
|
||||
|
@ -1725,6 +1737,29 @@ def _get_local_timezone():
|
|||
return os.environ.get('TZ', 'dateutil/:')
|
||||
|
||||
|
||||
def _check_series_localize_timestamps(s, timezone):
|
||||
"""
|
||||
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
|
||||
|
||||
If the input series is not a timestamp series, then the same series is returned. If the input
|
||||
series is a timestamp series, then a converted series is returned.
|
||||
|
||||
:param s: pandas.Series
|
||||
:param timezone: the timezone to convert. if None then use local timezone
|
||||
:return pandas.Series that have been converted to tz-naive
|
||||
"""
|
||||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
from pandas.api.types import is_datetime64tz_dtype
|
||||
tz = timezone or _get_local_timezone()
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64tz_dtype(s.dtype):
|
||||
return s.dt.tz_convert(tz).dt.tz_localize(None)
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def _check_dataframe_localize_timestamps(pdf, timezone):
|
||||
"""
|
||||
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
|
||||
|
@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
|
|||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
from pandas.api.types import is_datetime64tz_dtype
|
||||
tz = timezone or _get_local_timezone()
|
||||
for column, series in pdf.iteritems():
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64tz_dtype(series.dtype):
|
||||
pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
|
||||
pdf[column] = _check_series_localize_timestamps(series, timezone)
|
||||
return pdf
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
"""
|
||||
User-defined function related classes and functions
|
||||
"""
|
||||
import sys
|
||||
import inspect
|
||||
import functools
|
||||
|
||||
from pyspark import SparkContext, since
|
||||
|
@ -24,6 +26,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
|
|||
from pyspark.sql.column import Column, _to_java_column, _to_seq
|
||||
from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
|
||||
_parse_datatype_string, to_arrow_type, to_arrow_schema
|
||||
from pyspark.util import _get_argspec
|
||||
|
||||
__all__ = ["UDFRegistration"]
|
||||
|
||||
|
@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType):
|
|||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from pyspark.sql.utils import require_minimum_pyarrow_version
|
||||
|
||||
require_minimum_pyarrow_version()
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
|
||||
# See SPARK-23569.
|
||||
argspec = inspect.getargspec(f)
|
||||
else:
|
||||
argspec = inspect.getfullargspec(f)
|
||||
argspec = _get_argspec(f)
|
||||
|
||||
if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
|
||||
argspec.varargs is None:
|
||||
|
@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType):
|
|||
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
|
||||
)
|
||||
|
||||
if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
|
||||
if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
|
||||
and len(argspec.args) not in (1, 2):
|
||||
raise ValueError(
|
||||
"Invalid function: pandas_udfs with function type GROUPED_MAP "
|
||||
"must take a single arg that is a pandas DataFrame."
|
||||
)
|
||||
"must take either one argument (data) or two arguments (key, data).")
|
||||
|
||||
# Set the name of the UserDefinedFunction object to be the name of function f
|
||||
udf_obj = UserDefinedFunction(
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import inspect
|
||||
from py4j.protocol import Py4JJavaError
|
||||
|
||||
__all__ = []
|
||||
|
@ -45,6 +48,19 @@ def _exception_message(excp):
|
|||
return str(excp)
|
||||
|
||||
|
||||
def _get_argspec(f):
|
||||
"""
|
||||
Get argspec of a function. Supports both Python 2 and Python 3.
|
||||
"""
|
||||
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
|
||||
# See SPARK-23569.
|
||||
if sys.version_info[0] < 3:
|
||||
argspec = inspect.getargspec(f)
|
||||
else:
|
||||
argspec = inspect.getfullargspec(f)
|
||||
return argspec
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
(failure_count, test_count) = doctest.testmod()
|
||||
|
|
|
@ -34,6 +34,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \
|
|||
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
|
||||
BatchedSerializer, ArrowStreamPandasSerializer
|
||||
from pyspark.sql.types import to_arrow_type
|
||||
from pyspark.util import _get_argspec
|
||||
from pyspark import shuffle
|
||||
|
||||
pickleSer = PickleSerializer()
|
||||
|
@ -91,10 +92,16 @@ def wrap_scalar_pandas_udf(f, return_type):
|
|||
|
||||
|
||||
def wrap_grouped_map_pandas_udf(f, return_type):
|
||||
def wrapped(*series):
|
||||
def wrapped(key_series, value_series):
|
||||
import pandas as pd
|
||||
argspec = _get_argspec(f)
|
||||
|
||||
if len(argspec.args) == 1:
|
||||
result = f(pd.concat(value_series, axis=1))
|
||||
elif len(argspec.args) == 2:
|
||||
key = tuple(s[0] for s in key_series)
|
||||
result = f(key, pd.concat(value_series, axis=1))
|
||||
|
||||
result = f(pd.concat(series, axis=1))
|
||||
if not isinstance(result, pd.DataFrame):
|
||||
raise TypeError("Return type of the user-defined function should be "
|
||||
"pandas.DataFrame, but is {}".format(type(result)))
|
||||
|
@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type):
|
|||
num_udfs = read_int(infile)
|
||||
udfs = {}
|
||||
call_udf = []
|
||||
for i in range(num_udfs):
|
||||
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
|
||||
udfs['f%d' % i] = udf
|
||||
args = ["a[%d]" % o for o in arg_offsets]
|
||||
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
|
||||
# Create function like this:
|
||||
# lambda a: (f0(a0), f1(a1, a2), f2(a3))
|
||||
# In the special case of a single UDF this will return a single result rather
|
||||
# than a tuple of results; this is the format that the JVM side expects.
|
||||
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
|
||||
mapper = eval(mapper_str, udfs)
|
||||
mapper_str = ""
|
||||
if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
|
||||
# Create function like this:
|
||||
# lambda a: f([a[0]], [a[0], a[1]])
|
||||
|
||||
# We assume there is only one UDF here because grouped map doesn't
|
||||
# support combining multiple UDFs.
|
||||
assert num_udfs == 1
|
||||
|
||||
# See FlatMapGroupsInPandasExec for how arg_offsets are used to
|
||||
# distinguish between grouping attributes and data attributes
|
||||
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
|
||||
udfs['f'] = udf
|
||||
split_offset = arg_offsets[0] + 1
|
||||
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
|
||||
arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
|
||||
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
|
||||
else:
|
||||
# Create function like this:
|
||||
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
|
||||
# In the special case of a single UDF this will return a single result rather
|
||||
# than a tuple of results; this is the format that the JVM side expects.
|
||||
for i in range(num_udfs):
|
||||
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
|
||||
udfs['f%d' % i] = udf
|
||||
args = ["a[%d]" % o for o in arg_offsets]
|
||||
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
|
||||
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
|
||||
|
||||
mapper = eval(mapper_str, udfs)
|
||||
func = lambda _, it: map(mapper, it)
|
||||
|
||||
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.execution.python
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
|
||||
|
@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec(
|
|||
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
|
||||
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
|
||||
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
|
||||
val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
|
||||
val schema = StructType(child.schema.drop(groupingAttributes.length))
|
||||
val sessionLocalTimeZone = conf.sessionLocalTimeZone
|
||||
val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
|
||||
|
||||
// Deduplicate the grouping attributes.
|
||||
// If a grouping attribute also appears in data attributes, then we don't need to send the
|
||||
// grouping attribute to Python worker. If a grouping attribute is not in data attributes,
|
||||
// then we need to send this grouping attribute to python worker.
|
||||
//
|
||||
// We use argOffsets to distinguish grouping attributes and data attributes as following:
|
||||
//
|
||||
// argOffsets[0] is the length of grouping attributes
|
||||
// argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
|
||||
// argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
|
||||
|
||||
val dataAttributes = child.output.drop(groupingAttributes.length)
|
||||
val groupingIndicesInData = groupingAttributes.map { attribute =>
|
||||
dataAttributes.indexWhere(attribute.semanticEquals)
|
||||
}
|
||||
|
||||
val groupingArgOffsets = new ArrayBuffer[Int]
|
||||
val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
|
||||
val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
|
||||
|
||||
// Non duplicate grouping attributes are added to nonDupGroupingAttributes and
|
||||
// their offsets are 0, 1, 2 ...
|
||||
// Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
|
||||
// their offsets are n + index, where n is the total number of non duplicate grouping
|
||||
// attributes and index is the index in the data attributes that the grouping attribute
|
||||
// is a duplicate of.
|
||||
|
||||
groupingAttributes.zip(groupingIndicesInData).foreach {
|
||||
case (attribute, index) =>
|
||||
if (index == -1) {
|
||||
groupingArgOffsets += nonDupGroupingAttributes.length
|
||||
nonDupGroupingAttributes += attribute
|
||||
} else {
|
||||
groupingArgOffsets += index + nonDupGroupingSize
|
||||
}
|
||||
}
|
||||
|
||||
val dataArgOffsets = nonDupGroupingAttributes.length until
|
||||
(nonDupGroupingAttributes.length + dataAttributes.length)
|
||||
|
||||
val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
|
||||
|
||||
// Attributes after deduplication
|
||||
val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
|
||||
val dedupSchema = StructType.fromAttributes(dedupAttributes)
|
||||
|
||||
inputRDD.mapPartitionsInternal { iter =>
|
||||
val grouped = if (groupingAttributes.isEmpty) {
|
||||
Iterator(iter)
|
||||
} else {
|
||||
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
|
||||
val dropGrouping =
|
||||
UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output)
|
||||
val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
|
||||
groupedIter.map {
|
||||
case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
|
||||
case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec(
|
|||
|
||||
val columnarBatchIter = new ArrowPythonRunner(
|
||||
chainedFunc, bufferSize, reuseWorker,
|
||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
|
||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema,
|
||||
sessionLocalTimeZone, pandasRespectSessionTimeZone)
|
||||
.compute(grouped, context.partitionId(), context)
|
||||
|
||||
|
|
Loading…
Reference in a new issue