[SPARK-27463][PYTHON] Support Dataframe Cogroup via Pandas UDFs
### What changes were proposed in this pull request?

Adds a new cogroup Pandas UDF. This allows two grouped DataFrames to be cogrouped together, with a `(pandas.DataFrame, pandas.DataFrame) -> pandas.DataFrame` UDF applied to each cogroup.

**Example usage**

```
from pyspark.sql.functions import pandas_udf, PandasUDFType

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))

df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))

@pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def asof_join(l, r):
    return pd.merge_asof(l, r, on="time", by="id")

df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
```

which produces:

```
+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
|20000101|  2|2.0|  y|
|20000102|  2|4.0|  y|
+--------+---+---+---+
```

### How was this patch tested?

Added unit test `test_pandas_udf_cogrouped_map`.

Closes #24981 from d80tb7/SPARK-27463-poc-arrow-stream.

Authored-by: Chris Martin <chris@cmartinit.co.uk>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
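For reference, a fully self-contained version of the snippet above (a sketch assuming a local `SparkSession`; the snippet in the description relies on an ambient `spark` session and omits the `import pandas as pd` that `asof_join` needs):

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.master("local[2]").getOrCreate()

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))
df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))


@pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def asof_join(l, r):
    # Each cogroup arrives as two pandas.DataFrames; the returned DataFrame
    # must match the schema declared above.
    return pd.merge_asof(l, r, on="time", by="id")


df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
spark.stop()
```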
```diff
@@ -48,6 +48,7 @@ private[spark] object PythonEvalType {
   val SQL_WINDOW_AGG_PANDAS_UDF = 203
   val SQL_SCALAR_PANDAS_ITER_UDF = 204
   val SQL_MAP_PANDAS_ITER_UDF = 205
+  val SQL_COGROUPED_MAP_PANDAS_UDF = 206

   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
@@ -58,6 +59,7 @@ private[spark] object PythonEvalType {
     case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
     case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
     case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
+    case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
   }
 }
```
```diff
@@ -75,6 +75,7 @@ class PythonEvalType(object):
     SQL_WINDOW_AGG_PANDAS_UDF = 203
     SQL_SCALAR_PANDAS_ITER_UDF = 204
     SQL_MAP_PANDAS_ITER_UDF = 205
+    SQL_COGROUPED_MAP_PANDAS_UDF = 206


 def portable_hash(x):
```
```diff
@@ -401,6 +401,32 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         return "ArrowStreamPandasUDFSerializer"


+class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
+
+    def load_stream(self, stream):
+        """
+        Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
+        lists of pandas.Series.
+        """
+        import pyarrow as pa
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 2:
+                batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
+                batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
+                yield (
+                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
+                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
+                )
+
+            elif dataframes_in_group != 0:
+                raise ValueError(
+                    'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))
+
+
 class BatchedSerializer(Serializer):

     """
```
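To make the stream format concrete, here is a minimal round-trip sketch of the framing `load_stream` consumes (a sketch assuming only `pandas` and a reasonably recent `pyarrow`; ints are big-endian as with `read_int`, and each side of a cogroup is a complete Arrow IPC stream):

```python
import io
import struct

import pandas as pd
import pyarrow as pa


def write_arrow_stream(pdf, sink):
    # Serialize one pandas.DataFrame as a complete Arrow IPC stream.
    table = pa.Table.from_pandas(pdf, preserve_index=False)
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)


buf = io.BytesIO()
left = pd.DataFrame({"id": [1, 1], "v1": [1.0, 3.0]})
right = pd.DataFrame({"id": [1], "v2": ["x"]})

buf.write(struct.pack("!i", 2))   # 2 == two dataframes in this cogroup
write_arrow_stream(left, buf)     # left side, one Arrow stream
write_arrow_stream(right, buf)    # right side, another Arrow stream
buf.write(struct.pack("!i", 0))   # 0 == end of data

# Read it back the way load_stream does.
buf.seek(0)
while True:
    (dataframes_in_group,) = struct.unpack("!i", buf.read(4))
    if dataframes_in_group == 0:
        break
    assert dataframes_in_group == 2
    left_table = pa.ipc.open_stream(buf).read_all()
    right_table = pa.ipc.open_stream(buf).read_all()
    print(left_table.to_pandas(), right_table.to_pandas(), sep="\n\n")
```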
python/pyspark/sql/cogroup.py (new file, 98 lines)

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import since
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame


class CoGroupedData(object):
    """
    A logical grouping of two :class:`GroupedData`,
    created by :func:`GroupedData.cogroup`.

    .. note:: Experimental

    .. versionadded:: 3.0
    """

    def __init__(self, gd1, gd2):
        self._gd1 = gd1
        self._gd2 = gd2
        self.sql_ctx = gd1.sql_ctx

    @since(3.0)
    def apply(self, udf):
        """
        Applies a function to each cogroup using a pandas udf and returns the result
        as a `DataFrame`.

        The user-defined function should take two `pandas.DataFrame` and return another
        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together
        as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
        are combined as a :class:`DataFrame`.

        The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
        returnType of the pandas udf.

        .. note:: This function requires a full shuffle. All the data of a cogroup will be loaded
            into memory, so the user should be aware of the potential OOM risk if data is skewed
            and certain groups are too large to fit in memory.

        .. note:: Experimental

        :param udf: a cogrouped map user-defined function returned by
            :func:`pyspark.sql.functions.pandas_udf`.

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df1 = spark.createDataFrame(
        ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
        ...     ("time", "id", "v1"))
        >>> df2 = spark.createDataFrame(
        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
        ...     ("time", "id", "v2"))
        >>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
        ... def asof_join(l, r):
        ...     return pd.merge_asof(l, r, on="time", by="id")
        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
        +--------+---+---+---+
        |    time| id| v1| v2|
        +--------+---+---+---+
        |20000101|  1|1.0|  x|
        |20000102|  1|3.0|  x|
        |20000101|  2|2.0|  y|
        |20000102|  2|4.0|  y|
        +--------+---+---+---+

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        """
        # Columns are special because hasattr always return True
        if isinstance(udf, Column) or not hasattr(udf, 'func') \
                or udf.evalType != PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
                             "COGROUPED_MAP.")
        all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
        udf_column = udf(*all_cols)
        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)

    @staticmethod
    def _extract_cols(gd):
        df = gd._df
        return [df[col] for col in df.columns]
```
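As validated in `udf.py` further down, a `COGROUPED_MAP` function may also take three arguments, in which case the grouping key is passed first. A hypothetical sketch reusing the docstring's `df1`/`df2`:

```python
@pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def asof_join_with_key(key, l, r):
    # key is a tuple with one element per grouping expression, here (id,);
    # handy when the transformation needs to know which cogroup it is in.
    return pd.merge_asof(l, r, on="time", by="id")

df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join_with_key).show()
```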
```diff
@@ -2814,6 +2814,8 @@ class PandasUDFType(object):

     GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

+    COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
+
     GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF

     MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
@@ -3320,7 +3322,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-                PythonEvalType.SQL_MAP_PANDAS_ITER_UDF]:
+                PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+                PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
         raise ValueError("Invalid functionType: "
                          "functionType must be one the values from PandasUDFType")
```
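The shorthand and the eval-type constant are the same value on both sides of the Py4J boundary, which is what lets the worker dispatch on it; for instance:

```python
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import PandasUDFType

# PandasUDFType.COGROUPED_MAP is just an alias for the new eval type 206.
assert PandasUDFType.COGROUPED_MAP == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF == 206
```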
```diff
@@ -22,6 +22,7 @@ from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
 from pyspark.sql.column import Column, _to_seq
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import *
+from pyspark.sql.cogroup import CoGroupedData

 __all__ = ["GroupedData"]

@@ -218,6 +219,15 @@ class GroupedData(object):
         jgd = self._jgd.pivot(pivot_col, values)
         return GroupedData(jgd, self._df)

+    @since(3.0)
+    def cogroup(self, other):
+        """
+        Cogroups this group with another group so that we can run cogrouped operations.
+
+        See :class:`CoGroupedData` for the operations that can be run.
+        """
+        return CoGroupedData(self, other)
+
     @since(2.3)
     def apply(self, udf):
         """
@@ -232,7 +242,7 @@ class GroupedData(object):
         The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
         returnType of the pandas udf.

-        .. note:: This function requires a full shuffle. all the data of a group will be loaded
+        .. note:: This function requires a full shuffle. All the data of a group will be loaded
             into memory, so the user should be aware of the potential OOM risk if data is skewed
             and certain groups are too large to fit in memory.
```
python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py (new file, 280 lines)

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
import sys

from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StructType, StructField
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
    import pandas as pd
    from pandas.util.testing import assert_frame_equal, assert_series_equal

if have_pyarrow:
    import pyarrow as pa


"""
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
"""
if sys.version < '3':
    _check_column_type = False
else:
    _check_column_type = True


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class CoGroupedMapPandasUDFTests(ReusedSQLTestCase):

    @property
    def data1(self):
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks')))\
            .withColumn("v", col('k') * 10)\
            .drop('ks')

    @property
    def data2(self):
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v2", col('k') * 100) \
            .drop('ks')

    def test_simple(self):
        self._test_merge(self.data1, self.data2)

    def test_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_merge(left, self.data2)

    def test_right_group_empty(self):
        right = self.data2.where(col("id") % 2 == 0)
        self._test_merge(self.data1, right)

    def test_different_schemas(self):
        right = self.data2.withColumn('v3', lit('a'))
        self._test_merge(self.data1, right, 'id long, k int, v int, v2 int, v3 string')

    def test_complex_group_by(self):
        left = pd.DataFrame.from_dict({
            'id': [1, 2, 3],
            'k': [5, 6, 7],
            'v': [9, 10, 11]
        })

        right = pd.DataFrame.from_dict({
            'id': [11, 12, 13],
            'k': [5, 6, 7],
            'v2': [90, 100, 110]
        })

        left_gdf = self.spark\
            .createDataFrame(left)\
            .groupby(col('id') % 2 == 0)

        right_gdf = self.spark \
            .createDataFrame(right) \
            .groupby(col('id') % 2 == 0)

        @pandas_udf('k long, v long, v2 long', PandasUDFType.COGROUPED_MAP)
        def merge_pandas(l, r):
            return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k'])

        result = left_gdf \
            .cogroup(right_gdf) \
            .apply(merge_pandas) \
            .sort(['k']) \
            .toPandas()

        expected = pd.DataFrame.from_dict({
            'k': [5, 6, 7],
            'v': [9, 10, 11],
            'v2': [90, 100, 110]
        })

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_empty_group_by(self):
        left = self.data1
        right = self.data2

        @pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        result = left.groupby().cogroup(right.groupby())\
            .apply(merge_pandas) \
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
        df = self.spark.range(0, 10).toDF('v1')
        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

        result = df.groupby().cogroup(df.groupby())\
            .apply(pandas_udf(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
                              'sum1 int, sum2 int',
                              PandasUDFType.COGROUPED_MAP)).collect()

        self.assertEquals(result[0]['sum1'], 165)
        self.assertEquals(result[0]['sum2'], 165)

    def test_with_key_left(self):
        self._test_with_key(self.data1, self.data1, isLeft=True)

    def test_with_key_right(self):
        self._test_with_key(self.data1, self.data1, isLeft=False)

    def test_with_key_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(left, self.data1, isLeft=True)

    def test_with_key_right_group_empty(self):
        right = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(self.data1, right, isLeft=False)

    def test_with_key_complex(self):

        @pandas_udf('id long, k int, v int, key boolean', PandasUDFType.COGROUPED_MAP)
        def left_assign_key(key, l, _):
            return l.assign(key=key[0])

        result = self.data1 \
            .groupby(col('id') % 2 == 0)\
            .cogroup(self.data2.groupby(col('id') % 2 == 0)) \
            .apply(left_assign_key) \
            .sort(['id', 'k']) \
            .toPandas()

        expected = self.data1.toPandas()
        expected = expected.assign(key=expected.id % 2 == 0)

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_wrong_return_type(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    NotImplementedError,
                    'Invalid returnType.*cogrouped map Pandas UDF.*MapType'):
                pandas_udf(
                    lambda l, r: l,
                    'id long, v map<int, int>',
                    PandasUDFType.COGROUPED_MAP)

    def test_wrong_args(self):
        # Test that we get a sensible exception invalid values passed to apply
        left = self.data1
        right = self.data2
        with QuietTest(self.sc):
            # Function rather than a udf
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                left.groupby('id').cogroup(right.groupby('id')).apply(lambda l, r: l)

            # Udf missing return type
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                left.groupby('id').cogroup(right.groupby('id'))\
                    .apply(udf(lambda l, r: l, DoubleType()))

            # Pass in expression rather than udf
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                left.groupby('id').cogroup(right.groupby('id')).apply(left.v + 1)

            # Zero arg function
            with self.assertRaisesRegexp(ValueError, 'Invalid function'):
                left.groupby('id').cogroup(right.groupby('id'))\
                    .apply(pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))

            # Udf without PandasUDFType
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                left.groupby('id').cogroup(right.groupby('id'))\
                    .apply(pandas_udf(lambda x, y: x, DoubleType()))

            # Udf with incorrect PandasUDFType
            with self.assertRaisesRegexp(ValueError, 'Invalid udf.*COGROUPED_MAP'):
                left.groupby('id').cogroup(right.groupby('id'))\
                    .apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))

    @staticmethod
    def _test_with_key(left, right, isLeft):

        @pandas_udf('id long, k int, v int, key long', PandasUDFType.COGROUPED_MAP)
        def right_assign_key(key, l, r):
            return l.assign(key=key[0]) if isLeft else r.assign(key=key[0])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .apply(right_assign_key) \
            .toPandas()

        expected = left.toPandas() if isLeft else right.toPandas()
        expected = expected.assign(key=expected.id)

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    @staticmethod
    def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'):

        @pandas_udf(output_schema, PandasUDFType.COGROUPED_MAP)
        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .apply(merge_pandas)\
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result, check_column_type=_check_column_type)


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_cogrouped_map import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
```
```diff
@@ -42,6 +42,7 @@ def _create_udf(f, returnType, evalType):
     if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+                    PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):

@@ -65,6 +66,13 @@ def _create_udf(f, returnType, evalType):
                 "Invalid function: pandas_udfs with function type GROUPED_MAP "
                 "must take either one argument (data) or two arguments (key, data).")

+        if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF \
+                and len(argspec.args) not in (2, 3):
+            raise ValueError(
+                "Invalid function: pandas_udfs with function type COGROUPED_MAP "
+                "must take either two arguments (left, right) "
+                "or three arguments (key, left, right).")
+
     # Set the name of the UserDefinedFunction object to be the name of function f
     udf_obj = UserDefinedFunction(
         f, returnType=returnType, name=None, evalType=evalType, deterministic=True)

@@ -147,6 +155,17 @@ class UserDefinedFunction(object):
             else:
                 raise TypeError("Invalid returnType for map iterator Pandas "
                                 "UDFs: returnType must be a StructType.")
+        elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+            if isinstance(self._returnType_placeholder, StructType):
+                try:
+                    to_arrow_type(self._returnType_placeholder)
+                except TypeError:
+                    raise NotImplementedError(
+                        "Invalid returnType with cogrouped map Pandas UDFs: "
+                        "%s is not supported" % str(self._returnType_placeholder))
+            else:
+                raise TypeError("Invalid returnType for cogrouped map Pandas "
+                                "UDFs: returnType must be a StructType.")
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
                 # StructType is not yet allowed as a return type, explicitly check here to fail fast
```
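A standalone sketch of the arity rule (using `inspect.getfullargspec`, which behaves like pyspark's `_get_argspec` for plain Python 3 functions):

```python
import inspect


def cogrouped_map_arity_ok(f):
    # Mirrors the check above: (left, right) or (key, left, right).
    return len(inspect.getfullargspec(f).args) in (2, 3)


assert cogrouped_map_arity_ok(lambda l, r: l)
assert cogrouped_map_arity_ok(lambda key, l, r: l)
assert not cogrouped_map_arity_ok(lambda just_one: just_one)  # would raise ValueError
```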
```diff
@@ -39,7 +39,7 @@ from pyspark.resourceinformation import ResourceInformation
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
-    BatchedSerializer, ArrowStreamPandasUDFSerializer
+    BatchedSerializer, ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
 from pyspark.sql.types import to_arrow_type, StructType
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
```
```diff
@@ -121,6 +121,33 @@ def wrap_pandas_iter_udf(f, return_type):
                                  map(verify_result_type, f(*iterator)))


+def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
+
+    def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
+        import pandas as pd
+
+        left_df = pd.concat(left_value_series, axis=1)
+        right_df = pd.concat(right_value_series, axis=1)
+
+        if len(argspec.args) == 2:
+            result = f(left_df, right_df)
+        elif len(argspec.args) == 3:
+            key_series = left_key_series if not left_df.empty else right_key_series
+            key = tuple(s[0] for s in key_series)
+            result = f(key, left_df, right_df)
+        if not isinstance(result, pd.DataFrame):
+            raise TypeError("Return type of the user-defined function should be "
+                            "pandas.DataFrame, but is {}".format(type(result)))
+        if not len(result.columns) == len(return_type):
+            raise RuntimeError(
+                "Number of columns of the returned pandas.DataFrame "
+                "doesn't match specified schema. "
+                "Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
+        return result
+
+    return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))]
+
+
 def wrap_grouped_map_pandas_udf(f, return_type, argspec):

     def wrapped(key_series, value_series):
```
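What `wrapped` does can be replayed with plain pandas (hypothetical column data; in the worker the series arrive deserialized from Arrow):

```python
import pandas as pd

# Each side's value columns arrive as a list of pandas.Series, one per column.
left_value_series = [pd.Series([1, 1], name="id"), pd.Series([1.0, 3.0], name="v1")]
right_value_series = [pd.Series([1], name="id"), pd.Series(["x"], name="v2")]

# The wrapper concatenates them back into one DataFrame per side...
left_df = pd.concat(left_value_series, axis=1)
right_df = pd.concat(right_value_series, axis=1)

# ...then calls the user function with two (or key + two) DataFrames and
# type-checks the result, as the wrapper above does.
result = pd.merge(left_df, right_df, on="id")
assert isinstance(result, pd.DataFrame)
print(result)
```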
```diff
@@ -244,6 +271,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = _get_argspec(chained_func)  # signature was lost when wrapping it
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
+    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+        argspec = _get_argspec(chained_func)  # signature was lost when wrapping it
+        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
```
```diff
@@ -258,6 +288,7 @@ def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}

     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                     PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                      PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                      PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
@@ -280,6 +311,9 @@ def read_udfs(pickleSer, infile, eval_type):
             "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
             .lower() == "true"

+        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        else:
             # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
             # pandas Series. See SPARK-27240.
             df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
```
```diff
@@ -343,6 +377,32 @@ def read_udfs(pickleSer, infile, eval_type):
         # profiling is not supported for UDF
         return func, None, ser, ser

+    def extract_key_value_indexes(grouped_arg_offsets):
+        """
+        Helper function to extract the key and value indexes from arg_offsets for the grouped and
+        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.
+
+        :param grouped_arg_offsets: List containing the key and value indexes of columns of the
+            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
+            number of DataFrames. Each group has the following format:
+                group[0]: length of group
+                group[1]: length of key indexes
+                group[2.. group[1] +2]: key attributes
+                group[group[1] +3 group[0]]: value attributes
+        """
+        parsed = []
+        idx = 0
+        while idx < len(grouped_arg_offsets):
+            offsets_len = grouped_arg_offsets[idx]
+            idx += 1
+            offsets = grouped_arg_offsets[idx: idx + offsets_len]
+            split_index = offsets[0] + 1
+            offset_keys = offsets[1: split_index]
+            offset_values = offsets[split_index:]
+            parsed.append([offset_keys, offset_values])
+            idx += offsets_len
+        return parsed
+
     udfs = {}
     call_udf = []
     mapper_str = ""
```
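A worked example of the parsing (the function body copied out so it runs standalone, on hypothetical offsets for two DataFrames):

```python
def extract_key_value_indexes(grouped_arg_offsets):
    # Same logic as above: each group is [length, num_keys, keys..., values...].
    parsed = []
    idx = 0
    while idx < len(grouped_arg_offsets):
        offsets_len = grouped_arg_offsets[idx]
        idx += 1
        offsets = grouped_arg_offsets[idx: idx + offsets_len]
        split_index = offsets[0] + 1
        parsed.append([offsets[1: split_index], offsets[split_index:]])
        idx += offsets_len
    return parsed


# Left side: 4 offsets -> 1 key column (index 0), 2 value columns (1, 2).
# Right side: 3 offsets -> 1 key column (0), 1 value column (1).
print(extract_key_value_indexes([4, 1, 0, 1, 2, 3, 1, 0, 1]))
# [[[0], [1, 2]], [[0], [1]]]
```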
```diff
@@ -359,10 +419,24 @@ def read_udfs(pickleSer, infile, eval_type):
         arg_offsets, udf = read_single_udf(
             pickleSer, infile, eval_type, runner_conf, udf_index=0)
         udfs['f'] = udf
-        split_offset = arg_offsets[0] + 1
-        arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
-        arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
-        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+        keys = ["a[%d]" % (o,) for o in parsed_offsets[0][0]]
+        vals = ["a[%d]" % (o, ) for o in parsed_offsets[0][1]]
+        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(keys), ", ".join(vals))
+    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+        # We assume there is only one UDF here because cogrouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+        arg_offsets, udf = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        udfs['f'] = udf
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+        df1_keys = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][0]]
+        df1_vals = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][1]]
+        df2_keys = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][0]]
+        df2_vals = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][1]]
+        mapper_str = "lambda a: f([%s], [%s], [%s], [%s])" % (
+            ", ".join(df1_keys), ", ".join(df1_vals), ", ".join(df2_keys), ", ".join(df2_vals))
     else:
         # Create function like this:
         # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
```
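With the hypothetical parsed offsets from the previous sketch, the generated cogrouped mapper looks like this:

```python
parsed_offsets = [[[0], [1, 2]], [[0], [1]]]

df1_keys = ["a[0][%d]" % o for o in parsed_offsets[0][0]]
df1_vals = ["a[0][%d]" % o for o in parsed_offsets[0][1]]
df2_keys = ["a[1][%d]" % o for o in parsed_offsets[1][0]]
df2_vals = ["a[1][%d]" % o for o in parsed_offsets[1][1]]
mapper_str = "lambda a: f([%s], [%s], [%s], [%s])" % (
    ", ".join(df1_keys), ", ".join(df1_vals), ", ".join(df2_keys), ", ".join(df2_vals))
print(mapper_str)
# lambda a: f([a[0][0]], [a[0][1], a[0][2]], [a[1][0]], [a[1][1]])
```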
```diff
@@ -1191,6 +1191,12 @@ class Analyzer(
       // To resolve duplicate expression IDs for Join and Intersect
       case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
+      case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) =>
+        val leftRes = leftAttributes
+          .map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute])
+        val rightRes = rightAttributes
+          .map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute])
+        f.copy(leftAttributes = leftRes, rightAttributes = rightRes)
       // intersect/except will be rewritten to join at the begininng of optimizer. Here we need to
       // deduplicate the right side plan, so that we won't produce an invalid self-join later.
       case i @ Intersect(left, right, _) if !i.duplicateResolved =>
```
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}

 /**
- * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame.
+ * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
  * This is used by DataFrame.groupby().apply().
  */
 case class FlatMapGroupsInPandas(
@@ -40,7 +40,7 @@ case class FlatMapGroupsInPandas(
 }

 /**
- * Map partitions using an udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame).
+ * Map partitions using a udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame).
  * This is used by DataFrame.mapInPandas()
  */
 case class MapInPandas(
@@ -51,6 +51,21 @@ case class MapInPandas(
   override val producedAttributes = AttributeSet(output)
 }

+/**
+ * Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe
+ * This is used by DataFrame.groupby().cogroup().apply().
+ */
+case class FlatMapCoGroupsInPandas(
+    leftAttributes: Seq[Attribute],
+    rightAttributes: Seq[Attribute],
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    left: LogicalPlan,
+    right: LogicalPlan) extends BinaryNode {
+
+  override val producedAttributes = AttributeSet(output)
+}
+
 trait BaseEvalPython extends UnaryNode {

   def udfs: Seq[PythonUDF]
```
```diff
@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
  */
 @Stable
 class RelationalGroupedDataset protected[sql](
-    df: DataFrame,
-    groupingExprs: Seq[Expression],
+    val df: DataFrame,
+    val groupingExprs: Seq[Expression],
     groupType: RelationalGroupedDataset.GroupType) {

   private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
@@ -523,6 +523,48 @@ class RelationalGroupedDataset protected[sql](
     Dataset.ofRows(df.sparkSession, plan)
   }

+  /**
+   * Applies a vectorized python user-defined function to each cogrouped data.
+   * The user-defined function defines a transformation:
+   * `pandas.DataFrame`, `pandas.DataFrame` -> `pandas.DataFrame`.
+   * For each group in the cogrouped data, all elements in the group are passed as a
+   * `pandas.DataFrame` and the results for all cogroups are combined into a new [[DataFrame]].
+   *
+   * This function uses Apache Arrow as serialization format between Java executors and Python
+   * workers.
+   */
+  private[sql] def flatMapCoGroupsInPandas(
+      r: RelationalGroupedDataset,
+      expr: PythonUDF): DataFrame = {
+    require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      "Must pass a cogrouped map udf")
+    require(expr.dataType.isInstanceOf[StructType],
+      s"The returnType of the udf must be a ${StructType.simpleString}")
+
+    val leftGroupingNamedExpressions = groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+
+    val rightGroupingNamedExpressions = r.groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+
+    val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
+    val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)
+
+    val leftChild = df.logicalPlan
+    val rightChild = r.df.logicalPlan
+
+    val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
+    val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)
+
+    val output = expr.dataType.asInstanceOf[StructType].toAttributes
+    val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)
+    Dataset.ofRows(df.sparkSession, plan)
+  }
+
   override def toString: String = {
     val builder = new StringBuilder
     builder.append("RelationalGroupedDataset: [grouping expressions: [")
```
```diff
@@ -682,6 +682,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           f, p, b, is, ot, planLater(child)) :: Nil
       case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
         execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
+      case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) =>
+        execution.python.FlatMapCoGroupsInPandasExec(
+          leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil
       case logical.MapInPandas(func, output, child) =>
         execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
```
```diff
@@ -19,12 +19,9 @@ package org.apache.spark.sql.execution.python

 import java.io._
 import java.net._
-import java.util.concurrent.atomic.AtomicBoolean

 import scala.collection.JavaConverters._

-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter

 import org.apache.spark._
 import org.apache.spark.api.python._
@@ -33,7 +30,6 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 import org.apache.spark.util.Utils

 /**
@@ -46,7 +42,7 @@ class ArrowPythonRunner(
     schema: StructType,
     timeZoneId: String,
     conf: Map[String, String])
-  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
+  extends BaseArrowPythonRunner[Iterator[InternalRow]](
     funcs, evalType, argOffsets) {

   override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
@@ -119,72 +115,4 @@ class ArrowPythonRunner(
       }
     }
   }
-
-  protected override def newReaderIterator(
-      stream: DataInputStream,
-      writerThread: WriterThread,
-      startTime: Long,
-      env: SparkEnv,
-      worker: Socket,
-      releasedOrClosed: AtomicBoolean,
-      context: TaskContext): Iterator[ColumnarBatch] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
-
-      private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-        s"stdin reader for $pythonExec", 0, Long.MaxValue)
-
-      private var reader: ArrowStreamReader = _
-      private var root: VectorSchemaRoot = _
-      private var schema: StructType = _
-      private var vectors: Array[ColumnVector] = _
-
-      context.addTaskCompletionListener[Unit] { _ =>
-        if (reader != null) {
-          reader.close(false)
-        }
-        allocator.close()
-      }
-
-      private var batchLoaded = true
-
-      protected override def read(): ColumnarBatch = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
-        }
-        try {
-          if (reader != null && batchLoaded) {
-            batchLoaded = reader.loadNextBatch()
-            if (batchLoaded) {
-              val batch = new ColumnarBatch(vectors)
-              batch.setNumRows(root.getRowCount)
-              batch
-            } else {
-              reader.close(false)
-              allocator.close()
-              // Reach end of stream. Call `read()` again to read control data.
-              read()
-            }
-          } else {
-            stream.readInt() match {
-              case SpecialLengths.START_ARROW_STREAM =>
-                reader = new ArrowStreamReader(stream, allocator)
-                root = reader.getVectorSchemaRoot()
-                schema = ArrowUtils.fromArrowSchema(root.getSchema())
-                vectors = root.getFieldVectors().asScala.map { vector =>
-                  new ArrowColumnVector(vector)
-                }.toArray[ColumnVector]
-                read()
-              case SpecialLengths.TIMING_DATA =>
-                handleTimingData()
-                read()
-              case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-                throw handlePythonException()
-              case SpecialLengths.END_OF_DATA_SECTION =>
-                handleEndOfDataSection()
-                null
-            }
-          }
-        } catch handleException
-      }
-    }
-  }
 }
```
New file (112 lines), the shared Arrow reader base extracted from `ArrowPythonRunner`:

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.python

import java.io._
import java.net._
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.JavaConverters._

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamReader

import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}

/**
 * Common functionality for a udf runner that exchanges data with Python worker via Arrow stream.
 */
abstract class BaseArrowPythonRunner[T](
    funcs: Seq[ChainedPythonFunctions],
    evalType: Int,
    argOffsets: Array[Array[Int]])
  extends BasePythonRunner[T, ColumnarBatch](funcs, evalType, argOffsets) {

  protected override def newReaderIterator(
      stream: DataInputStream,
      writerThread: WriterThread,
      startTime: Long,
      env: SparkEnv,
      worker: Socket,
      releasedOrClosed: AtomicBoolean,
      context: TaskContext): Iterator[ColumnarBatch] = {

    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {

      private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
        s"stdin reader for $pythonExec", 0, Long.MaxValue)

      private var reader: ArrowStreamReader = _
      private var root: VectorSchemaRoot = _
      private var schema: StructType = _
      private var vectors: Array[ColumnVector] = _

      context.addTaskCompletionListener[Unit] { _ =>
        if (reader != null) {
          reader.close(false)
        }
        allocator.close()
      }

      private var batchLoaded = true

      protected override def read(): ColumnarBatch = {
        if (writerThread.exception.isDefined) {
          throw writerThread.exception.get
        }
        try {
          if (reader != null && batchLoaded) {
            batchLoaded = reader.loadNextBatch()
            if (batchLoaded) {
              val batch = new ColumnarBatch(vectors)
              batch.setNumRows(root.getRowCount)
              batch
            } else {
              reader.close(false)
              allocator.close()
              // Reach end of stream. Call `read()` again to read control data.
              read()
            }
          } else {
            stream.readInt() match {
              case SpecialLengths.START_ARROW_STREAM =>
                reader = new ArrowStreamReader(stream, allocator)
                root = reader.getVectorSchemaRoot()
                schema = ArrowUtils.fromArrowSchema(root.getSchema())
                vectors = root.getFieldVectors().asScala.map { vector =>
                  new ArrowColumnVector(vector)
                }.toArray[ColumnVector]
                read()
              case SpecialLengths.TIMING_DATA =>
                handleTimingData()
                read()
              case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
                throw handlePythonException()
              case SpecialLengths.END_OF_DATA_SECTION =>
                handleEndOfDataSection()
                null
            }
          }
        } catch handleException
      }
    }
  }
}
```
New file (137 lines), the shared base for grouped/cogrouped plans:

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, UnsafeProjection}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}

/**
 * Base functionality for plans which execute grouped python udfs.
 */
abstract class BasePandasGroupExec(
    func: Expression,
    output: Seq[Attribute])
  extends SparkPlan {

  protected val sessionLocalTimeZone = conf.sessionLocalTimeZone

  protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)

  protected val pandasFunction = func.asInstanceOf[PythonUDF].func

  protected val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))

  override def producedAttributes: AttributeSet = AttributeSet(output)

  /**
   * Passes the data to the python runner and converts the resulting
   * ColumnarBatch into internal rows.
   */
  protected def executePython[T](
      data: Iterator[T],
      runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = {

    val context = TaskContext.get()
    val columnarBatchIter = runner.compute(data, context.partitionId(), context)
    val unsafeProj = UnsafeProjection.create(output, output)

    columnarBatchIter.flatMap { batch =>
      // UDF returns a StructType column in ColumnarBatch, select the children here
      val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
      val outputVectors = output.indices.map(structVector.getChild)
      val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
      flattenedBatch.setNumRows(batch.numRows())
      flattenedBatch.rowIterator.asScala
    }.map(unsafeProj)
  }

  /**
   * Groups according to grouping attributes and then projects into the deduplicated schema.
   */
  protected def groupAndProject(
      input: Iterator[InternalRow],
      groupingAttributes: Seq[Attribute],
      inputSchema: Seq[Attribute],
      dedupSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
    val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema)
    val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema)
    groupedIter.map {
      case (k, groupedRowIter) => (k, groupedRowIter.map(dedupProj))
    }
  }

  /**
   * Returns the deduplicated attributes of the spark plan and the arg offsets of the
   * keys and values.
   *
   * The deduplicated attributes are needed because the spark plan may contain an attribute
   * twice; once in the key and once in the value. For any such attribute we need to
   * deduplicate.
   *
   * The arg offsets are used to distinguish grouping attributes and data attributes
   * as following:
   *
   * argOffsets[0] is the length of the argOffsets array
   *
   * argOffsets[1] is the length of grouping attributes
   * argOffsets[2 .. argOffsets[0]+2] is the arg offsets for grouping attributes
   *
   * argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes
   */
  protected def resolveArgOffsets(
      child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = {

    val dataAttributes = child.output.drop(groupingAttributes.length)
    val groupingIndicesInData = groupingAttributes.map { attribute =>
      dataAttributes.indexWhere(attribute.semanticEquals)
    }

    val groupingArgOffsets = new ArrayBuffer[Int]
    val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
    val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)

    groupingAttributes.zip(groupingIndicesInData).foreach {
      case (attribute, index) =>
        if (index == -1) {
          groupingArgOffsets += nonDupGroupingAttributes.length
          nonDupGroupingAttributes += attribute
        } else {
          groupingArgOffsets += index + nonDupGroupingSize
        }
    }

    val dataArgOffsets = nonDupGroupingAttributes.length until
      (nonDupGroupingAttributes.length + dataAttributes.length)

    val argOffsetsLength = groupingAttributes.length + dataArgOffsets.length + 1
    val argOffsets = Array(argOffsetsLength,
      groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets

    // Attributes after deduplication
    val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
    (dedupAttributes, argOffsets)
  }
}
```
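The same bookkeeping in a runnable Python sketch (hypothetical column names standing in for Attribute objects; `child_output` is the Project built by `flatMapCoGroupsInPandas`, i.e. grouping columns prepended to the data columns):

```python
def resolve_arg_offsets(child_output, grouping):
    # Data columns follow the prepended grouping columns.
    data = child_output[len(grouping):]
    grouping_idx_in_data = [data.index(g) if g in data else -1 for g in grouping]

    grouping_offsets, non_dup_grouping = [], []
    non_dup_size = grouping_idx_in_data.count(-1)
    for col, i in zip(grouping, grouping_idx_in_data):
        if i == -1:
            # Grouping column absent from the data: send it explicitly.
            grouping_offsets.append(len(non_dup_grouping))
            non_dup_grouping.append(col)
        else:
            # Grouping column already among the data columns: reuse it.
            grouping_offsets.append(i + non_dup_size)
    data_offsets = list(range(len(non_dup_grouping), len(non_dup_grouping) + len(data)))
    arg_offsets = [len(grouping) + len(data_offsets) + 1, len(grouping)] \
        + grouping_offsets + data_offsets
    return non_dup_grouping + data, arg_offsets


# Grouping on "id", which also appears in the data columns, so it is not duplicated.
print(resolve_arg_offsets(["id", "id", "k", "v"], ["id"]))
# (['id', 'k', 'v'], [5, 1, 0, 0, 1, 2])
```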
New file (113 lines), the runner that sends two cogrouped sides to the worker:

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import java.io._
import java.net._

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils


/**
 * Python UDF Runner for cogrouped udfs. Although the data is exchanged with the python
 * worker via arrow, we cannot use `ArrowPythonRunner` as we need to send more than one
 * dataframe.
 */
class CogroupedArrowPythonRunner(
    funcs: Seq[ChainedPythonFunctions],
    evalType: Int,
    argOffsets: Array[Array[Int]],
    leftSchema: StructType,
    rightSchema: StructType,
    timeZoneId: String,
    conf: Map[String, String])
  extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])](
    funcs, evalType, argOffsets) {

  protected def newWriterThread(
      env: SparkEnv,
      worker: Socket,
      inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])],
      partitionIndex: Int,
      context: TaskContext): WriterThread = {

    new WriterThread(env, worker, inputIterator, partitionIndex, context) {

      protected override def writeCommand(dataOut: DataOutputStream): Unit = {

        // Write config for the worker as a number of key -> value pairs of strings
        dataOut.writeInt(conf.size)
        for ((k, v) <- conf) {
          PythonRDD.writeUTF(k, dataOut)
          PythonRDD.writeUTF(v, dataOut)
        }

        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
      }

      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
        // For each group we first send the number of dataframes in the group, then the
        // first df, then the second df. End of data is marked by sending 0.
        while (inputIterator.hasNext) {
          dataOut.writeInt(2)
          val (nextLeft, nextRight) = inputIterator.next()
          writeGroup(nextLeft, leftSchema, dataOut, "left")
          writeGroup(nextRight, rightSchema, dataOut, "right")
        }
        dataOut.writeInt(0)
      }

      def writeGroup(
          group: Iterator[InternalRow],
          schema: StructType,
          dataOut: DataOutputStream,
          name: String) = {
        val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
          s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
        val root = VectorSchemaRoot.create(arrowSchema, allocator)

        Utils.tryWithSafeFinally {
          val writer = new ArrowStreamWriter(root, null, dataOut)
          val arrowWriter = ArrowWriter.create(root)
          writer.start()

          while (group.hasNext) {
            arrowWriter.write(group.next())
          }
          arrowWriter.finish()
          writer.writeBatch()
          writer.end()
        }{
          root.close()
          allocator.close()
        }
      }
    }
  }
}
```
@ -0,0 +1,97 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.python
|
||||
|
||||
import org.apache.spark.api.python.PythonEvalType
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
|
||||
/**
|
||||
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInPandas]]
|
||||
*
|
||||
* The input dataframes are first Cogrouped. Rows from each side of the cogroup are passed to the
|
||||
* Python worker via Arrow. As each side of the cogroup may have a different schema we send every
|
||||
* group in its own Arrow stream.
|
||||
* The Python worker turns the resulting record batches to `pandas.DataFrame`s, invokes the
|
||||
* user-defined function, and passes the resulting `pandas.DataFrame`
|
||||
* as an Arrow record batch. Finally, each record batch is turned to
|
||||
* Iterator[InternalRow] using ColumnarBatch.
|
||||
*
|
||||
* Note on memory usage:
|
||||
* Both the Python worker and the Java executor need to have enough memory to
|
||||
* hold the largest cogroup. The memory on the Java side is used to construct the
|
||||
* record batches (off heap memory). The memory on the Python side is used for
|
||||
* holding the `pandas.DataFrame`. It's possible to further split one group into
|
||||
* multiple record batches to reduce the memory footprint on the Java side, this
|
||||
* is left as future work.
|
||||
*/
|
||||
case class FlatMapCoGroupsInPandasExec(
|
||||
leftGroup: Seq[Attribute],
|
||||
rightGroup: Seq[Attribute],
|
||||
func: Expression,
|
||||
output: Seq[Attribute],
|
||||
left: SparkPlan,
|
||||
right: SparkPlan)
|
||||
extends BasePandasGroupExec(func, output) with BinaryExecNode {
|
||||
|
||||
override def outputPartitioning: Partitioning = left.outputPartitioning
|
||||
|
||||
override def requiredChildDistribution: Seq[Distribution] = {
|
||||
val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
|
||||
val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
|
||||
leftDist :: rightDist :: Nil
|
||||
}
|
||||
|
||||
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
|
||||
leftGroup
|
||||
.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
|
||||
}
|
||||
|
||||
  override protected def doExecute(): RDD[InternalRow] = {

    val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
    val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)

    // Map cogrouped rows to ArrowPythonRunner results; only execute if the partition is not empty
    left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
      if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else {

        val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup)
        val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup)
        val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
          .map { case (_, l, r) => (l, r) }

        val runner = new CogroupedArrowPythonRunner(
          chainedFunc,
          PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
          Array(leftArgOffsets ++ rightArgOffsets),
          StructType.fromAttributes(leftDedup),
          StructType.fromAttributes(rightDedup),
          sessionLocalTimeZone,
          pythonRunnerConf)

        executePython(data, runner)
      }
    }
  }
}
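The distribution and ordering requirements above make both children arrive clustered by their grouping keys and sorted within each partition, which is what lets CoGroupedIterator pair up left and right groups in a single merge pass. As a rough illustration, here is a minimal Python sketch of that pairing (a hypothetical helper, not the Scala implementation; it assumes both inputs are already sorted by key):

```
from itertools import groupby

def cogroup_sorted(left_rows, right_rows, key):
    # Pair key-sorted inputs into (key, left_group, right_group) triples,
    # mirroring the merge pass CoGroupedIterator performs over the two
    # grouped children.
    left = [(k, list(g)) for k, g in groupby(left_rows, key)]
    right = [(k, list(g)) for k, g in groupby(right_rows, key)]
    i = j = 0
    while i < len(left) or j < len(right):
        lk = left[i][0] if i < len(left) else None
        rk = right[j][0] if j < len(right) else None
        if rk is None or (lk is not None and lk < rk):
            yield lk, left[i][1], []           # key only on the left
            i += 1
        elif lk is None or rk < lk:
            yield rk, [], right[j][1]          # key only on the right
            j += 1
        else:
            yield lk, left[i][1], right[j][1]  # key on both sides
            i += 1
            j += 1
```

A key present on only one side still yields a cogroup, with an empty group standing in for the missing side.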
@@ -17,19 +17,14 @@

 package org.apache.spark.sql.execution.python

-import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}


 /**
  * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
@@ -53,14 +48,10 @@ case class FlatMapGroupsInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
+  extends BasePandasGroupExec(func, output) with UnaryExecNode {

   override def outputPartitioning: Partitioning = child.outputPartitioning

-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
   override def requiredChildDistribution: Seq[Distribution] = {
     if (groupingAttributes.isEmpty) {
       AllTuples :: Nil
@@ -75,88 +66,23 @@ case class FlatMapGroupsInPandasExec(
   override protected def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute()

-    val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-    val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
-
-    // Deduplicate the grouping attributes.
-    // If a grouping attribute also appears in data attributes, then we don't need to send the
-    // grouping attribute to Python worker. If a grouping attribute is not in data attributes,
-    // then we need to send this grouping attribute to python worker.
-    //
-    // We use argOffsets to distinguish grouping attributes and data attributes as following:
-    //
-    // argOffsets[0] is the length of grouping attributes
-    // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
-    // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
-
-    val dataAttributes = child.output.drop(groupingAttributes.length)
-    val groupingIndicesInData = groupingAttributes.map { attribute =>
-      dataAttributes.indexWhere(attribute.semanticEquals)
-    }
-
-    val groupingArgOffsets = new ArrayBuffer[Int]
-    val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
-    val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
-
-    // Non duplicate grouping attributes are added to nonDupGroupingAttributes and
-    // their offsets are 0, 1, 2 ...
-    // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
-    // their offsets are n + index, where n is the total number of non duplicate grouping
-    // attributes and index is the index in the data attributes that the grouping attribute
-    // is a duplicate of.
-
-    groupingAttributes.zip(groupingIndicesInData).foreach {
-      case (attribute, index) =>
-        if (index == -1) {
-          groupingArgOffsets += nonDupGroupingAttributes.length
-          nonDupGroupingAttributes += attribute
-        } else {
-          groupingArgOffsets += index + nonDupGroupingSize
-        }
-    }
-
-    val dataArgOffsets = nonDupGroupingAttributes.length until
-      (nonDupGroupingAttributes.length + dataAttributes.length)
-
-    val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
-
-    // Attributes after deduplication
-    val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
-    val dedupSchema = StructType.fromAttributes(dedupAttributes)
+    val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes)

     // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
-      val grouped = if (groupingAttributes.isEmpty) {
-        Iterator(iter)
-      } else {
-        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
-        val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
-        groupedIter.map {
-          case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
-        }
-      }
-
-      val context = TaskContext.get()
+      val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes)
+        .map { case (_, x) => x }

-      val columnarBatchIter = new ArrowPythonRunner(
+      val runner = new ArrowPythonRunner(
         chainedFunc,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-        argOffsets,
-        dedupSchema,
+        Array(argOffsets),
+        StructType.fromAttributes(dedupAttributes),
         sessionLocalTimeZone,
-        pythonRunnerConf).compute(grouped, context.partitionId(), context)
+        pythonRunnerConf)

-      val unsafeProj = UnsafeProjection.create(output, output)
-
-      columnarBatchIter.flatMap { batch =>
-        // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
-        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
-        val outputVectors = output.indices.map(structVector.getChild)
-        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
-        flattenedBatch.setNumRows(batch.numRows())
-        flattenedBatch.rowIterator.asScala
-      }.map(unsafeProj)
+      executePython(data, runner)
     }}
   }
 }
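The comment block removed above documents the argOffsets layout that resolveArgOffsets now produces: argOffsets[0] is the number of grouping attributes, followed by the offsets of the grouping attributes and then those of the data attributes, with a grouping column that duplicates a data column pointed back at the data column instead of being sent twice. A small Python sketch of that scheme, using plain equality where the Scala code uses semanticEquals (the helper name and list-of-names representation are illustrative only):

```
def resolve_arg_offsets(grouping, data):
    # Layout: [num_grouping, grouping offsets..., data offsets...].
    # Grouping columns already present in `data` are not re-sent; their
    # offset points into the data section instead.
    non_dup = [g for g in grouping if g not in data]
    n = len(non_dup)
    grouping_offsets = []
    for g in grouping:
        if g in data:
            # duplicate: offset of the matching data column, shifted past
            # the non-duplicate grouping columns
            grouping_offsets.append(n + data.index(g))
        else:
            grouping_offsets.append(non_dup.index(g))
    data_offsets = list(range(n, n + len(data)))
    dedup = non_dup + data
    return dedup, [len(grouping)] + grouping_offsets + data_offsets

# Example: the grouping key "id" already appears in the data columns.
dedup, offsets = resolve_arg_offsets(["id"], ["id", "v1"])
# dedup == ["id", "v1"]; offsets == [1, 0, 0, 1]
```

Deduplicating this way means a column used both for grouping and as data crosses to the Python worker only once.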