[SPARK-27463][PYTHON] Support Dataframe Cogroup via Pandas UDFs

### What changes were proposed in this pull request?

Adds a new cogroup Pandas UDF. This allows two grouped DataFrames to be cogrouped together and a (pandas.DataFrame, pandas.DataFrame) -> pandas.DataFrame UDF to be applied to each cogroup.

**Example usage**

```
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
df1 = spark.createDataFrame(
   [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
   ("time", "id", "v1"))

df2 = spark.createDataFrame(
   [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))

@pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def asof_join(l, r):
    return pd.merge_asof(l, r, on="time", by="id")

df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()

```

        +--------+---+---+---+
        |    time| id| v1| v2|
        +--------+---+---+---+
        |20000101|  1|1.0|  x|
        |20000102|  1|3.0|  x|
        |20000101|  2|2.0|  y|
        |20000102|  2|4.0|  y|
        +--------+---+---+---+
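
The UDF can also take the grouping key as a first argument. A minimal sketch based on the tests added in this PR (the `key` column and its type are illustrative; `df1`/`df2` are as above):

```
@pandas_udf("time int, id int, v1 double, v2 string, key int", PandasUDFType.COGROUPED_MAP)
def asof_join_with_key(key, l, r):
    # `key` is a tuple holding the grouping key of the current cogroup, e.g. (1,)
    return pd.merge_asof(l, r, on="time", by="id").assign(key=key[0])

df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join_with_key).show()
```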

### How was this patch tested?

Added the unit test suite `test_pandas_udf_cogrouped_map`.

Closes #24981 from d80tb7/SPARK-27463-poc-arrow-stream.

Authored-by: Chris Martin <chris@cmartinit.co.uk>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
Committed: 2019-09-17 17:13:50 -07:00 (commit 05988b256e, parent 197732e1f4)
19 changed files with 1070 additions and 178 deletions


@ -48,6 +48,7 @@ private[spark] object PythonEvalType {
val SQL_WINDOW_AGG_PANDAS_UDF = 203
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@ -58,6 +59,7 @@ private[spark] object PythonEvalType {
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
}
}


@ -75,6 +75,7 @@ class PythonEvalType(object):
SQL_WINDOW_AGG_PANDAS_UDF = 203
SQL_SCALAR_PANDAS_ITER_UDF = 204
SQL_MAP_PANDAS_ITER_UDF = 205
SQL_COGROUPED_MAP_PANDAS_UDF = 206
def portable_hash(x):


@ -401,6 +401,32 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
return "ArrowStreamPandasUDFSerializer"
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
def load_stream(self, stream):
"""
Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
lists of pandas.Series.
"""
import pyarrow as pa
dataframes_in_group = None
while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)
if dataframes_in_group == 2:
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
)
elif dataframes_in_group != 0:
raise ValueError(
'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))
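
For context, a hedged standalone sketch of the framing this deserializer expects (the JVM writer is CogroupedArrowPythonRunner, later in this diff): for every cogroup the stream carries the int 2, then one Arrow IPC stream per side, and finally 0 to mark end of data. The helper name and tables below are illustrative only.

```
import io
import struct

import pyarrow as pa

def write_cogroup_stream(out, left_table, right_table):
    # one cogroup: count, then one Arrow IPC stream per DataFrame
    out.write(struct.pack("!i", 2))
    for table in (left_table, right_table):
        writer = pa.RecordBatchStreamWriter(out, table.schema)
        for batch in table.to_batches():
            writer.write_batch(batch)
        writer.close()  # writes the Arrow end-of-stream marker
    out.write(struct.pack("!i", 0))  # end of data

buf = io.BytesIO()
write_cogroup_stream(buf,
                     pa.Table.from_pydict({"id": [1, 1], "v1": [1.0, 3.0]}),
                     pa.Table.from_pydict({"id": [1], "v2": ["x"]}))
```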
class BatchedSerializer(Serializer):
"""


@ -0,0 +1,98 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark import since
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
class CoGroupedData(object):
"""
A logical grouping of two :class:`GroupedData`,
created by :func:`GroupedData.cogroup`.
.. note:: Experimental
.. versionadded:: 3.0
"""
def __init__(self, gd1, gd2):
self._gd1 = gd1
self._gd2 = gd2
self.sql_ctx = gd1.sql_ctx
@since(3.0)
def apply(self, udf):
"""
Applies a function to each cogroup using a pandas udf and returns the result
as a `DataFrame`.
The user-defined function should take two `pandas.DataFrame` and return another
`pandas.DataFrame`. For each side of the cogroup, all columns are passed together
as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
are combined as a :class:`DataFrame`.
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.
.. note:: This function requires a full shuffle. All the data of a cogroup will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
.. note:: Experimental
:param udf: a cogrouped map user-defined function returned by
:func:`pyspark.sql.functions.pandas_udf`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df1 = spark.createDataFrame(
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
... ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
... [(20000101, 1, "x"), (20000101, 2, "y")],
... ("time", "id", "v2"))
>>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
... def asof_join(l, r):
... return pd.merge_asof(l, r, on="time", by="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
+--------+---+---+---+
| time| id| v1| v2|
+--------+---+---+---+
|20000101| 1|1.0| x|
|20000102| 1|3.0| x|
|20000101| 2|2.0| y|
|20000102| 2|4.0| y|
+--------+---+---+---+
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
"""
# Columns are special because hasattr always return True
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.evalType != PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
"COGROUPED_MAP.")
all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)
@staticmethod
def _extract_cols(gd):
df = gd._df
return [df[col] for col in df.columns]


@ -2814,6 +2814,8 @@ class PandasUDFType(object):
GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
@ -3320,7 +3322,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF]:
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")


@ -22,6 +22,7 @@ from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *
from pyspark.sql.cogroup import CoGroupedData
__all__ = ["GroupedData"]
@ -218,6 +219,15 @@ class GroupedData(object):
jgd = self._jgd.pivot(pivot_col, values)
return GroupedData(jgd, self._df)
@since(3.0)
def cogroup(self, other):
"""
Cogroups this group with another group so that we can run cogrouped operations.
See :class:`CoGroupedData` for the operations that can be run.
"""
return CoGroupedData(self, other)
@since(2.3)
def apply(self, udf):
"""
@ -232,7 +242,7 @@ class GroupedData(object):
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.
.. note:: This function requires a full shuffle. all the data of a group will be loaded
.. note:: This function requires a full shuffle. All the data of a group will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.


@ -0,0 +1,280 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import sys
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StructType, StructField
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal, assert_series_equal
if have_pyarrow:
import pyarrow as pa
"""
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
"""
if sys.version < '3':
_check_column_type = False
else:
_check_column_type = True
@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message)
class CoGroupedMapPandasUDFTests(ReusedSQLTestCase):
@property
def data1(self):
return self.spark.range(10).toDF('id') \
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
.withColumn("k", explode(col('ks')))\
.withColumn("v", col('k') * 10)\
.drop('ks')
@property
def data2(self):
return self.spark.range(10).toDF('id') \
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
.withColumn("k", explode(col('ks'))) \
.withColumn("v2", col('k') * 100) \
.drop('ks')
def test_simple(self):
self._test_merge(self.data1, self.data2)
def test_left_group_empty(self):
left = self.data1.where(col("id") % 2 == 0)
self._test_merge(left, self.data2)
def test_right_group_empty(self):
right = self.data2.where(col("id") % 2 == 0)
self._test_merge(self.data1, right)
def test_different_schemas(self):
right = self.data2.withColumn('v3', lit('a'))
self._test_merge(self.data1, right, 'id long, k int, v int, v2 int, v3 string')
def test_complex_group_by(self):
left = pd.DataFrame.from_dict({
'id': [1, 2, 3],
'k': [5, 6, 7],
'v': [9, 10, 11]
})
right = pd.DataFrame.from_dict({
'id': [11, 12, 13],
'k': [5, 6, 7],
'v2': [90, 100, 110]
})
left_gdf = self.spark\
.createDataFrame(left)\
.groupby(col('id') % 2 == 0)
right_gdf = self.spark \
.createDataFrame(right) \
.groupby(col('id') % 2 == 0)
@pandas_udf('k long, v long, v2 long', PandasUDFType.COGROUPED_MAP)
def merge_pandas(l, r):
return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k'])
result = left_gdf \
.cogroup(right_gdf) \
.apply(merge_pandas) \
.sort(['k']) \
.toPandas()
expected = pd.DataFrame.from_dict({
'k': [5, 6, 7],
'v': [9, 10, 11],
'v2': [90, 100, 110]
})
assert_frame_equal(expected, result, check_column_type=_check_column_type)
def test_empty_group_by(self):
left = self.data1
right = self.data2
@pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
def merge_pandas(l, r):
return pd.merge(l, r, on=['id', 'k'])
result = left.groupby().cogroup(right.groupby())\
.apply(merge_pandas) \
.sort(['id', 'k']) \
.toPandas()
left = left.toPandas()
right = right.toPandas()
expected = pd \
.merge(left, right, on=['id', 'k']) \
.sort_values(by=['id', 'k'])
assert_frame_equal(expected, result, check_column_type=_check_column_type)
def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
df = self.spark.range(0, 10).toDF('v1')
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
.withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
result = df.groupby().cogroup(df.groupby())\
.apply(pandas_udf(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
'sum1 int, sum2 int',
PandasUDFType.COGROUPED_MAP)).collect()
self.assertEquals(result[0]['sum1'], 165)
self.assertEquals(result[0]['sum2'], 165)
def test_with_key_left(self):
self._test_with_key(self.data1, self.data1, isLeft=True)
def test_with_key_right(self):
self._test_with_key(self.data1, self.data1, isLeft=False)
def test_with_key_left_group_empty(self):
left = self.data1.where(col("id") % 2 == 0)
self._test_with_key(left, self.data1, isLeft=True)
def test_with_key_right_group_empty(self):
right = self.data1.where(col("id") % 2 == 0)
self._test_with_key(self.data1, right, isLeft=False)
def test_with_key_complex(self):
@pandas_udf('id long, k int, v int, key boolean', PandasUDFType.COGROUPED_MAP)
def left_assign_key(key, l, _):
return l.assign(key=key[0])
result = self.data1 \
.groupby(col('id') % 2 == 0)\
.cogroup(self.data2.groupby(col('id') % 2 == 0)) \
.apply(left_assign_key) \
.sort(['id', 'k']) \
.toPandas()
expected = self.data1.toPandas()
expected = expected.assign(key=expected.id % 2 == 0)
assert_frame_equal(expected, result, check_column_type=_check_column_type)
def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid returnType.*cogrouped map Pandas UDF.*MapType'):
pandas_udf(
lambda l, r: l,
'id long, v map<int, int>',
PandasUDFType.COGROUPED_MAP)
def test_wrong_args(self):
# Test that we get a sensible exception invalid values passed to apply
left = self.data1
right = self.data2
with QuietTest(self.sc):
# Function rather than a udf
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
left.groupby('id').cogroup(right.groupby('id')).apply(lambda l, r: l)
# Udf missing return type
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
left.groupby('id').cogroup(right.groupby('id'))\
.apply(udf(lambda l, r: l, DoubleType()))
# Pass in expression rather than udf
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
left.groupby('id').cogroup(right.groupby('id')).apply(left.v + 1)
# Zero arg function
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
left.groupby('id').cogroup(right.groupby('id'))\
.apply(pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
# Udf without PandasUDFType
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
left.groupby('id').cogroup(right.groupby('id'))\
.apply(pandas_udf(lambda x, y: x, DoubleType()))
# Udf with incorrect PandasUDFType
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*COGROUPED_MAP'):
left.groupby('id').cogroup(right.groupby('id'))\
.apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
@staticmethod
def _test_with_key(left, right, isLeft):
@pandas_udf('id long, k int, v int, key long', PandasUDFType.COGROUPED_MAP)
def right_assign_key(key, l, r):
return l.assign(key=key[0]) if isLeft else r.assign(key=key[0])
result = left \
.groupby('id') \
.cogroup(right.groupby('id')) \
.apply(right_assign_key) \
.toPandas()
expected = left.toPandas() if isLeft else right.toPandas()
expected = expected.assign(key=expected.id)
assert_frame_equal(expected, result, check_column_type=_check_column_type)
@staticmethod
def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'):
@pandas_udf(output_schema, PandasUDFType.COGROUPED_MAP)
def merge_pandas(l, r):
return pd.merge(l, r, on=['id', 'k'])
result = left \
.groupby('id') \
.cogroup(right.groupby('id')) \
.apply(merge_pandas)\
.sort(['id', 'k']) \
.toPandas()
left = left.toPandas()
right = right.toPandas()
expected = pd \
.merge(left, right, on=['id', 'k']) \
.sort_values(by=['id', 'k'])
assert_frame_equal(expected, result, check_column_type=_check_column_type)
if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_udf_cogrouped_map import *
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)


@ -42,6 +42,7 @@ def _create_udf(f, returnType, evalType):
if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
@ -65,6 +66,13 @@ def _create_udf(f, returnType, evalType):
"Invalid function: pandas_udfs with function type GROUPED_MAP "
"must take either one argument (data) or two arguments (key, data).")
if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF \
and len(argspec.args) not in (2, 3):
raise ValueError(
"Invalid function: pandas_udfs with function type COGROUPED_MAP "
"must take either two arguments (left, right) "
"or three arguments (key, left, right).")
# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
@ -147,6 +155,17 @@ class UserDefinedFunction(object):
else:
raise TypeError("Invalid returnType for map iterator Pandas "
"UDFs: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
if isinstance(self._returnType_placeholder, StructType):
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with cogrouped map Pandas UDFs: "
"%s is not supported" % str(self._returnType_placeholder))
else:
raise TypeError("Invalid returnType for cogrouped map Pandas "
"UDFs: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
# StructType is not yet allowed as a return type, explicitly check here to fail fast


@ -39,7 +39,7 @@ from pyspark.resourceinformation import ResourceInformation
from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasUDFSerializer
BatchedSerializer, ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
from pyspark.sql.types import to_arrow_type, StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle
@ -121,6 +121,33 @@ def wrap_pandas_iter_udf(f, return_type):
map(verify_result_type, f(*iterator)))
def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
import pandas as pd
left_df = pd.concat(left_value_series, axis=1)
right_df = pd.concat(right_value_series, axis=1)
if len(argspec.args) == 2:
result = f(left_df, right_df)
elif len(argspec.args) == 3:
key_series = left_key_series if not left_df.empty else right_key_series
key = tuple(s[0] for s in key_series)
result = f(key, left_df, right_df)
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be "
"pandas.DataFrame, but is {}".format(type(result)))
if not len(result.columns) == len(return_type):
raise RuntimeError(
"Number of columns of the returned pandas.DataFrame "
"doesn't match specified schema. "
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
return result
return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))]
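
To make the plumbing above concrete, a hedged toy illustration of what `wrapped` receives per cogroup and how it rebuilds its inputs (column names and values are hypothetical):

```
import pandas as pd

# each side arrives as one pandas.Series per key column and one per value column
left_key_series = [pd.Series([1, 1], name="id")]
left_value_series = [pd.Series([1, 1], name="id"),
                     pd.Series([20000101, 20000102], name="time"),
                     pd.Series([1.0, 3.0], name="v1")]

left_df = pd.concat(left_value_series, axis=1)  # DataFrame with columns id, time, v1
key = tuple(s[0] for s in left_key_series)      # grouping key for the 3-argument form
print(key)  # (1,)
```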
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrapped(key_series, value_series):
@ -244,6 +271,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = _get_argspec(chained_func) # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
argspec = _get_argspec(chained_func) # signature was lost when wrapping it
return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@ -258,6 +288,7 @@ def read_udfs(pickleSer, infile, eval_type):
runner_conf = {}
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
@ -280,6 +311,9 @@ def read_udfs(pickleSer, infile, eval_type):
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
.lower() == "true"
if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
else:
# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
# pandas Series. See SPARK-27240.
df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
@ -343,6 +377,32 @@ def read_udfs(pickleSer, infile, eval_type):
# profiling is not supported for UDF
return func, None, ser, ser
def extract_key_value_indexes(grouped_arg_offsets):
"""
Helper function to extract the key and value indexes from arg_offsets for the grouped and
cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.
:param grouped_arg_offsets: List containing the key and value indexes of columns of the
DataFrames to be passed to the udf. It consists of n repeating groups where n is the
number of DataFrames. Each group has the following format:
group[0]: length of group
group[1]: length of key indexes
group[2.. group[1] +2]: key attributes
group[group[1] +3 group[0]]: value attributes
"""
parsed = []
idx = 0
while idx < len(grouped_arg_offsets):
offsets_len = grouped_arg_offsets[idx]
idx += 1
offsets = grouped_arg_offsets[idx: idx + offsets_len]
split_index = offsets[0] + 1
offset_keys = offsets[1: split_index]
offset_values = offsets[split_index:]
parsed.append([offset_keys, offset_values])
idx += offsets_len
return parsed
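
A hypothetical worked example of this layout (editor's illustration, not part of the commit), assuming both sides are grouped by `id` over deduplicated columns `[id, k, v]`:

```
# per group: group[0] = 5 ints follow, group[1] = 1 key index,
#            key offsets = [0] (id), value offsets = [0, 1, 2] (id, k, v)
grouped_arg_offsets = [5, 1, 0, 0, 1, 2,   # first DataFrame
                       5, 1, 0, 0, 1, 2]   # second DataFrame, same layout here
# extract_key_value_indexes(grouped_arg_offsets) would return
# [[[0], [0, 1, 2]], [[0], [0, 1, 2]]]
```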
udfs = {}
call_udf = []
mapper_str = ""
@ -359,10 +419,24 @@ def read_udfs(pickleSer, infile, eval_type):
arg_offsets, udf = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=0)
udfs['f'] = udf
split_offset = arg_offsets[0] + 1
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
parsed_offsets = extract_key_value_indexes(arg_offsets)
keys = ["a[%d]" % (o,) for o in parsed_offsets[0][0]]
vals = ["a[%d]" % (o, ) for o in parsed_offsets[0][1]]
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(keys), ", ".join(vals))
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
# We assume there is only one UDF here because cogrouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
arg_offsets, udf = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=0)
udfs['f'] = udf
parsed_offsets = extract_key_value_indexes(arg_offsets)
df1_keys = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][0]]
df1_vals = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][1]]
df2_keys = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][0]]
df2_vals = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][1]]
mapper_str = "lambda a: f([%s], [%s], [%s], [%s])" % (
", ".join(df1_keys), ", ".join(df1_vals), ", ".join(df2_keys), ", ".join(df2_vals))
else:
# Create function like this:
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))


@ -1191,6 +1191,12 @@ class Analyzer(
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) =>
val leftRes = leftAttributes
.map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute])
val rightRes = rightAttributes
.map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute])
f.copy(leftAttributes = leftRes, rightAttributes = rightRes)
// intersect/except will be rewritten to join at the begininng of optimizer. Here we need to
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>


@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
/**
* FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame.
* FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
* This is used by DataFrame.groupby().apply().
*/
case class FlatMapGroupsInPandas(
@ -40,7 +40,7 @@ case class FlatMapGroupsInPandas(
}
/**
* Map partitions using an udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame).
* Map partitions using a udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame).
* This is used by DataFrame.mapInPandas()
*/
case class MapInPandas(
@ -51,6 +51,21 @@ case class MapInPandas(
override val producedAttributes = AttributeSet(output)
}
/**
* Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe
* This is used by DataFrame.groupby().cogroup().apply().
*/
case class FlatMapCoGroupsInPandas(
leftAttributes: Seq[Attribute],
rightAttributes: Seq[Attribute],
functionExpr: Expression,
output: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {
override val producedAttributes = AttributeSet(output)
}
trait BaseEvalPython extends UnaryNode {
def udfs: Seq[PythonUDF]


@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
*/
@Stable
class RelationalGroupedDataset protected[sql](
df: DataFrame,
groupingExprs: Seq[Expression],
val df: DataFrame,
val groupingExprs: Seq[Expression],
groupType: RelationalGroupedDataset.GroupType) {
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
@ -523,6 +523,48 @@ class RelationalGroupedDataset protected[sql](
Dataset.ofRows(df.sparkSession, plan)
}
/**
* Applies a vectorized python user-defined function to each cogrouped data.
* The user-defined function defines a transformation:
* `pandas.DataFrame`, `pandas.DataFrame` -> `pandas.DataFrame`.
* For each group in the cogrouped data, all elements in the group are passed as a
* `pandas.DataFrame` and the results for all cogroups are combined into a new [[DataFrame]].
*
* This function uses Apache Arrow as serialization format between Java executors and Python
* workers.
*/
private[sql] def flatMapCoGroupsInPandas(
r: RelationalGroupedDataset,
expr: PythonUDF): DataFrame = {
require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
"Must pass a cogrouped map udf")
require(expr.dataType.isInstanceOf[StructType],
s"The returnType of the udf must be a ${StructType.simpleString}")
val leftGroupingNamedExpressions = groupingExprs.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
val rightGroupingNamedExpressions = r.groupingExprs.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)
val leftChild = df.logicalPlan
val rightChild = r.df.logicalPlan
val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)
val output = expr.dataType.asInstanceOf[StructType].toAttributes
val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)
Dataset.ofRows(df.sparkSession, plan)
}
override def toString: String = {
val builder = new StringBuilder
builder.append("RelationalGroupedDataset: [grouping expressions: [")


@ -682,6 +682,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
f, p, b, is, ot, planLater(child)) :: Nil
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) =>
execution.python.FlatMapCoGroupsInPandasExec(
leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil
case logical.MapInPandas(func, output, child) =>
execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>


@ -19,12 +19,9 @@ package org.apache.spark.sql.execution.python
import java.io._
import java.net._
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark._
import org.apache.spark.api.python._
@ -33,7 +30,6 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.Utils
/**
@ -46,7 +42,7 @@ class ArrowPythonRunner(
schema: StructType,
timeZoneId: String,
conf: Map[String, String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
extends BaseArrowPythonRunner[Iterator[InternalRow]](
funcs, evalType, argOffsets) {
override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
@ -119,72 +115,4 @@ class ArrowPythonRunner(
}
}
protected override def newReaderIterator(
stream: DataInputStream,
writerThread: WriterThread,
startTime: Long,
env: SparkEnv,
worker: Socket,
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[ColumnarBatch] = {
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdin reader for $pythonExec", 0, Long.MaxValue)
private var reader: ArrowStreamReader = _
private var root: VectorSchemaRoot = _
private var schema: StructType = _
private var vectors: Array[ColumnVector] = _
context.addTaskCompletionListener[Unit] { _ =>
if (reader != null) {
reader.close(false)
}
allocator.close()
}
private var batchLoaded = true
protected override def read(): ColumnarBatch = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
try {
if (reader != null && batchLoaded) {
batchLoaded = reader.loadNextBatch()
if (batchLoaded) {
val batch = new ColumnarBatch(vectors)
batch.setNumRows(root.getRowCount)
batch
} else {
reader.close(false)
allocator.close()
// Reach end of stream. Call `read()` again to read control data.
read()
}
} else {
stream.readInt() match {
case SpecialLengths.START_ARROW_STREAM =>
reader = new ArrowStreamReader(stream, allocator)
root = reader.getVectorSchemaRoot()
schema = ArrowUtils.fromArrowSchema(root.getSchema())
vectors = root.getFieldVectors().asScala.map { vector =>
new ArrowColumnVector(vector)
}.toArray[ColumnVector]
read()
case SpecialLengths.TIMING_DATA =>
handleTimingData()
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
throw handlePythonException()
case SpecialLengths.END_OF_DATA_SECTION =>
handleEndOfDataSection()
null
}
}
} catch handleException
}
}
}
}


@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.python
import java.io._
import java.net._
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
/**
* Common functionality for a udf runner that exchanges data with Python worker via Arrow stream.
*/
abstract class BaseArrowPythonRunner[T](
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]])
extends BasePythonRunner[T, ColumnarBatch](funcs, evalType, argOffsets) {
protected override def newReaderIterator(
stream: DataInputStream,
writerThread: WriterThread,
startTime: Long,
env: SparkEnv,
worker: Socket,
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[ColumnarBatch] = {
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdin reader for $pythonExec", 0, Long.MaxValue)
private var reader: ArrowStreamReader = _
private var root: VectorSchemaRoot = _
private var schema: StructType = _
private var vectors: Array[ColumnVector] = _
context.addTaskCompletionListener[Unit] { _ =>
if (reader != null) {
reader.close(false)
}
allocator.close()
}
private var batchLoaded = true
protected override def read(): ColumnarBatch = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
try {
if (reader != null && batchLoaded) {
batchLoaded = reader.loadNextBatch()
if (batchLoaded) {
val batch = new ColumnarBatch(vectors)
batch.setNumRows(root.getRowCount)
batch
} else {
reader.close(false)
allocator.close()
// Reach end of stream. Call `read()` again to read control data.
read()
}
} else {
stream.readInt() match {
case SpecialLengths.START_ARROW_STREAM =>
reader = new ArrowStreamReader(stream, allocator)
root = reader.getVectorSchemaRoot()
schema = ArrowUtils.fromArrowSchema(root.getSchema())
vectors = root.getFieldVectors().asScala.map { vector =>
new ArrowColumnVector(vector)
}.toArray[ColumnVector]
read()
case SpecialLengths.TIMING_DATA =>
handleTimingData()
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
throw handlePythonException()
case SpecialLengths.END_OF_DATA_SECTION =>
handleEndOfDataSection()
null
}
}
} catch handleException
}
}
}
}


@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.python
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, UnsafeProjection}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
/**
* Base functionality for plans which execute grouped python udfs.
*/
abstract class BasePandasGroupExec(
func: Expression,
output: Seq[Attribute])
extends SparkPlan {
protected val sessionLocalTimeZone = conf.sessionLocalTimeZone
protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
protected val pandasFunction = func.asInstanceOf[PythonUDF].func
protected val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
override def producedAttributes: AttributeSet = AttributeSet(output)
/**
* passes the data to the python runner and converts the resulting
* ColumnarBatch into internal rows.
*/
protected def executePython[T](
data: Iterator[T],
runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = {
val context = TaskContext.get()
val columnarBatchIter = runner.compute(data, context.partitionId(), context)
val unsafeProj = UnsafeProjection.create(output, output)
columnarBatchIter.flatMap { batch =>
// UDF returns a StructType column in ColumnarBatch, select the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
val outputVectors = output.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())
flattenedBatch.rowIterator.asScala
}.map(unsafeProj)
}
/**
* groups according to grouping attributes and then projects into the deduplicated schema
*/
protected def groupAndProject(
input: Iterator[InternalRow],
groupingAttributes: Seq[Attribute],
inputSchema: Seq[Attribute],
dedupSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema)
val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema)
groupedIter.map {
case (k, groupedRowIter) => (k, groupedRowIter.map(dedupProj))
}
}
/**
* Returns the deduplicated attributes of the spark plan and the arg offsets of the
* keys and values.
*
* The deduplicated attributes are needed because the spark plan may contain an attribute
* twice; once in the key and once in the value. For any such attribute we need to
* deduplicate.
*
* The arg offsets are used to distinguish grouping attributes and data attributes
* as following:
*
* argOffsets[0] is the length of the argOffsets array
*
* argOffsets[1] is the length of grouping attribute
* argOffsets[2 .. argOffsets[0]+2] is the arg offsets for grouping attributes
*
* argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes
*/
protected def resolveArgOffsets(
child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = {
val dataAttributes = child.output.drop(groupingAttributes.length)
val groupingIndicesInData = groupingAttributes.map { attribute =>
dataAttributes.indexWhere(attribute.semanticEquals)
}
val groupingArgOffsets = new ArrayBuffer[Int]
val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
groupingAttributes.zip(groupingIndicesInData).foreach {
case (attribute, index) =>
if (index == -1) {
groupingArgOffsets += nonDupGroupingAttributes.length
nonDupGroupingAttributes += attribute
} else {
groupingArgOffsets += index + nonDupGroupingSize
}
}
val dataArgOffsets = nonDupGroupingAttributes.length until
(nonDupGroupingAttributes.length + dataAttributes.length)
val argOffsetsLength = groupingAttributes.length + dataArgOffsets.length + 1
val argOffsets = Array(argOffsetsLength,
groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets
// Attributes after deduplication
val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
(dedupAttributes, argOffsets)
}
}


@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.python
import java.io._
import java.net._
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
/**
* Python UDF Runner for cogrouped udfs. Although the data is exchanged with the python
* worker via arrow, we cannot use `ArrowPythonRunner` as we need to send more than one
* dataframe.
*/
class CogroupedArrowPythonRunner(
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]],
leftSchema: StructType,
rightSchema: StructType,
timeZoneId: String,
conf: Map[String, String])
extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])](
funcs, evalType, argOffsets) {
protected def newWriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])],
partitionIndex: Int,
context: TaskContext): WriterThread = {
new WriterThread(env, worker, inputIterator, partitionIndex, context) {
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
// Write config for the worker as a number of key -> value pairs of strings
dataOut.writeInt(conf.size)
for ((k, v) <- conf) {
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v, dataOut)
}
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
}
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
// For each we first send the number of dataframes in each group then send
// first df, then send second df. End of data is marked by sending 0.
while (inputIterator.hasNext) {
dataOut.writeInt(2)
val (nextLeft, nextRight) = inputIterator.next()
writeGroup(nextLeft, leftSchema, dataOut, "left")
writeGroup(nextRight, rightSchema, dataOut, "right")
}
dataOut.writeInt(0)
}
def writeGroup(
group: Iterator[InternalRow],
schema: StructType,
dataOut: DataOutputStream,
name: String) = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
Utils.tryWithSafeFinally {
val writer = new ArrowStreamWriter(root, null, dataOut)
val arrowWriter = ArrowWriter.create(root)
writer.start()
while (group.hasNext) {
arrowWriter.write(group.next())
}
arrowWriter.finish()
writer.writeBatch()
writer.end()
}{
root.close()
allocator.close()
}
}
}
}
}


@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.python
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
import org.apache.spark.sql.types.StructType
/**
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInPandas]]
*
* The input dataframes are first Cogrouped. Rows from each side of the cogroup are passed to the
* Python worker via Arrow. As each side of the cogroup may have a different schema we send every
* group in its own Arrow stream.
* The Python worker turns the resulting record batches to `pandas.DataFrame`s, invokes the
* user-defined function, and passes the resulting `pandas.DataFrame`
* as an Arrow record batch. Finally, each record batch is turned to
* Iterator[InternalRow] using ColumnarBatch.
*
* Note on memory usage:
* Both the Python worker and the Java executor need to have enough memory to
* hold the largest cogroup. The memory on the Java side is used to construct the
* record batches (off heap memory). The memory on the Python side is used for
* holding the `pandas.DataFrame`. It's possible to further split one group into
* multiple record batches to reduce the memory footprint on the Java side, this
* is left as future work.
*/
case class FlatMapCoGroupsInPandasExec(
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
func: Expression,
output: Seq[Attribute],
left: SparkPlan,
right: SparkPlan)
extends BasePandasGroupExec(func, output) with BinaryExecNode {
override def outputPartitioning: Partitioning = left.outputPartitioning
override def requiredChildDistribution: Seq[Distribution] = {
val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
leftDist :: rightDist :: Nil
}
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
leftGroup
.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
}
override protected def doExecute(): RDD[InternalRow] = {
val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)
// Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else {
val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup)
val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup)
val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
.map { case (_, l, r) => (l, r) }
val runner = new CogroupedArrowPythonRunner(
chainedFunc,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
Array(leftArgOffsets ++ rightArgOffsets),
StructType.fromAttributes(leftDedup),
StructType.fromAttributes(rightDedup),
sessionLocalTimeZone,
pythonRunnerConf)
executePython(data, runner)
}
}
}
}


@ -17,19 +17,14 @@
package org.apache.spark.sql.execution.python
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
/**
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
@ -53,14 +48,10 @@ case class FlatMapGroupsInPandasExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan)
extends UnaryExecNode {
private val pandasFunction = func.asInstanceOf[PythonUDF].func
extends BasePandasGroupExec(func, output) with UnaryExecNode {
override def outputPartitioning: Partitioning = child.outputPartitioning
override def producedAttributes: AttributeSet = AttributeSet(output)
override def requiredChildDistribution: Seq[Distribution] = {
if (groupingAttributes.isEmpty) {
AllTuples :: Nil
@ -75,88 +66,23 @@ case class FlatMapGroupsInPandasExec(
override protected def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute()
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
// Deduplicate the grouping attributes.
// If a grouping attribute also appears in data attributes, then we don't need to send the
// grouping attribute to Python worker. If a grouping attribute is not in data attributes,
// then we need to send this grouping attribute to python worker.
//
// We use argOffsets to distinguish grouping attributes and data attributes as following:
//
// argOffsets[0] is the length of grouping attributes
// argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
// argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
val dataAttributes = child.output.drop(groupingAttributes.length)
val groupingIndicesInData = groupingAttributes.map { attribute =>
dataAttributes.indexWhere(attribute.semanticEquals)
}
val groupingArgOffsets = new ArrayBuffer[Int]
val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
// Non duplicate grouping attributes are added to nonDupGroupingAttributes and
// their offsets are 0, 1, 2 ...
// Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
// their offsets are n + index, where n is the total number of non duplicate grouping
// attributes and index is the index in the data attributes that the grouping attribute
// is a duplicate of.
groupingAttributes.zip(groupingIndicesInData).foreach {
case (attribute, index) =>
if (index == -1) {
groupingArgOffsets += nonDupGroupingAttributes.length
nonDupGroupingAttributes += attribute
} else {
groupingArgOffsets += index + nonDupGroupingSize
}
}
val dataArgOffsets = nonDupGroupingAttributes.length until
(nonDupGroupingAttributes.length + dataAttributes.length)
val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
// Attributes after deduplication
val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
val dedupSchema = StructType.fromAttributes(dedupAttributes)
val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes)
// Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
val grouped = if (groupingAttributes.isEmpty) {
Iterator(iter)
} else {
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
groupedIter.map {
case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
}
}
val context = TaskContext.get()
val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes)
.map{case(_, x) => x}
val columnarBatchIter = new ArrowPythonRunner(
val runner = new ArrowPythonRunner(
chainedFunc,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
argOffsets,
dedupSchema,
Array(argOffsets),
StructType.fromAttributes(dedupAttributes),
sessionLocalTimeZone,
pythonRunnerConf).compute(grouped, context.partitionId(), context)
pythonRunnerConf)
val unsafeProj = UnsafeProjection.create(output, output)
columnarBatchIter.flatMap { batch =>
// Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
val outputVectors = output.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())
flattenedBatch.rowIterator.asScala
}.map(unsafeProj)
executePython(data, runner)
}}
}
}