[SPARK-27463][PYTHON] Support Dataframe Cogroup via Pandas UDFs
### What changes were proposed in this pull request? Adds a new cogroup Pandas UDF. This allows two grouped dataframes to be cogrouped together and apply a (pandas.DataFrame, pandas.DataFrame) -> pandas.DataFrame UDF to each cogroup. **Example usage** ``` from pyspark.sql.functions import pandas_udf, PandasUDFType df1 = spark.createDataFrame( [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ("time", "id", "v1")) df2 = spark.createDataFrame( [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2")) pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) def asof_join(l, r): return pd.merge_asof(l, r, on="time", by="id") df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() ``` +--------+---+---+---+ | time| id| v1| v2| +--------+---+---+---+ |20000101| 1|1.0| x| |20000102| 1|3.0| x| |20000101| 2|2.0| y| |20000102| 2|4.0| y| +--------+---+---+---+ ### How was this patch tested? Added unit test test_pandas_udf_cogrouped_map Closes #24981 from d80tb7/SPARK-27463-poc-arrow-stream. Authored-by: Chris Martin <chris@cmartinit.co.uk> Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
This commit is contained in:
parent
197732e1f4
commit
05988b256e
|
@ -48,6 +48,7 @@ private[spark] object PythonEvalType {
|
||||||
val SQL_WINDOW_AGG_PANDAS_UDF = 203
|
val SQL_WINDOW_AGG_PANDAS_UDF = 203
|
||||||
val SQL_SCALAR_PANDAS_ITER_UDF = 204
|
val SQL_SCALAR_PANDAS_ITER_UDF = 204
|
||||||
val SQL_MAP_PANDAS_ITER_UDF = 205
|
val SQL_MAP_PANDAS_ITER_UDF = 205
|
||||||
|
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
|
||||||
|
|
||||||
def toString(pythonEvalType: Int): String = pythonEvalType match {
|
def toString(pythonEvalType: Int): String = pythonEvalType match {
|
||||||
case NON_UDF => "NON_UDF"
|
case NON_UDF => "NON_UDF"
|
||||||
|
@ -58,6 +59,7 @@ private[spark] object PythonEvalType {
|
||||||
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
|
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
|
||||||
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
|
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
|
||||||
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
|
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
|
||||||
|
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,6 +75,7 @@ class PythonEvalType(object):
|
||||||
SQL_WINDOW_AGG_PANDAS_UDF = 203
|
SQL_WINDOW_AGG_PANDAS_UDF = 203
|
||||||
SQL_SCALAR_PANDAS_ITER_UDF = 204
|
SQL_SCALAR_PANDAS_ITER_UDF = 204
|
||||||
SQL_MAP_PANDAS_ITER_UDF = 205
|
SQL_MAP_PANDAS_ITER_UDF = 205
|
||||||
|
SQL_COGROUPED_MAP_PANDAS_UDF = 206
|
||||||
|
|
||||||
|
|
||||||
def portable_hash(x):
|
def portable_hash(x):
|
||||||
|
|
|
@ -401,6 +401,32 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
|
||||||
return "ArrowStreamPandasUDFSerializer"
|
return "ArrowStreamPandasUDFSerializer"
|
||||||
|
|
||||||
|
|
||||||
|
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
|
||||||
|
|
||||||
|
def load_stream(self, stream):
|
||||||
|
"""
|
||||||
|
Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
|
||||||
|
lists of pandas.Series.
|
||||||
|
"""
|
||||||
|
import pyarrow as pa
|
||||||
|
dataframes_in_group = None
|
||||||
|
|
||||||
|
while dataframes_in_group is None or dataframes_in_group > 0:
|
||||||
|
dataframes_in_group = read_int(stream)
|
||||||
|
|
||||||
|
if dataframes_in_group == 2:
|
||||||
|
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
|
||||||
|
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
|
||||||
|
yield (
|
||||||
|
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
|
||||||
|
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif dataframes_in_group != 0:
|
||||||
|
raise ValueError(
|
||||||
|
'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))
|
||||||
|
|
||||||
|
|
||||||
class BatchedSerializer(Serializer):
|
class BatchedSerializer(Serializer):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
98
python/pyspark/sql/cogroup.py
Normal file
98
python/pyspark/sql/cogroup.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
# contributor license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright ownership.
|
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
# (the "License"); you may not use this file except in compliance with
|
||||||
|
# the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from pyspark import since
|
||||||
|
from pyspark.rdd import PythonEvalType
|
||||||
|
from pyspark.sql.column import Column
|
||||||
|
from pyspark.sql.dataframe import DataFrame
|
||||||
|
|
||||||
|
|
||||||
|
class CoGroupedData(object):
|
||||||
|
"""
|
||||||
|
A logical grouping of two :class:`GroupedData`,
|
||||||
|
created by :func:`GroupedData.cogroup`.
|
||||||
|
|
||||||
|
.. note:: Experimental
|
||||||
|
|
||||||
|
.. versionadded:: 3.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, gd1, gd2):
|
||||||
|
self._gd1 = gd1
|
||||||
|
self._gd2 = gd2
|
||||||
|
self.sql_ctx = gd1.sql_ctx
|
||||||
|
|
||||||
|
@since(3.0)
|
||||||
|
def apply(self, udf):
|
||||||
|
"""
|
||||||
|
Applies a function to each cogroup using a pandas udf and returns the result
|
||||||
|
as a `DataFrame`.
|
||||||
|
|
||||||
|
The user-defined function should take two `pandas.DataFrame` and return another
|
||||||
|
`pandas.DataFrame`. For each side of the cogroup, all columns are passed together
|
||||||
|
as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
|
||||||
|
are combined as a :class:`DataFrame`.
|
||||||
|
|
||||||
|
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
|
||||||
|
returnType of the pandas udf.
|
||||||
|
|
||||||
|
.. note:: This function requires a full shuffle. All the data of a cogroup will be loaded
|
||||||
|
into memory, so the user should be aware of the potential OOM risk if data is skewed
|
||||||
|
and certain groups are too large to fit in memory.
|
||||||
|
|
||||||
|
.. note:: Experimental
|
||||||
|
|
||||||
|
:param udf: a cogrouped map user-defined function returned by
|
||||||
|
:func:`pyspark.sql.functions.pandas_udf`.
|
||||||
|
|
||||||
|
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||||
|
>>> df1 = spark.createDataFrame(
|
||||||
|
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
|
||||||
|
... ("time", "id", "v1"))
|
||||||
|
>>> df2 = spark.createDataFrame(
|
||||||
|
... [(20000101, 1, "x"), (20000101, 2, "y")],
|
||||||
|
... ("time", "id", "v2"))
|
||||||
|
>>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
|
||||||
|
... def asof_join(l, r):
|
||||||
|
... return pd.merge_asof(l, r, on="time", by="id")
|
||||||
|
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
|
||||||
|
+--------+---+---+---+
|
||||||
|
| time| id| v1| v2|
|
||||||
|
+--------+---+---+---+
|
||||||
|
|20000101| 1|1.0| x|
|
||||||
|
|20000102| 1|3.0| x|
|
||||||
|
|20000101| 2|2.0| y|
|
||||||
|
|20000102| 2|4.0| y|
|
||||||
|
+--------+---+---+---+
|
||||||
|
|
||||||
|
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Columns are special because hasattr always return True
|
||||||
|
if isinstance(udf, Column) or not hasattr(udf, 'func') \
|
||||||
|
or udf.evalType != PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
|
||||||
|
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
|
||||||
|
"COGROUPED_MAP.")
|
||||||
|
all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
|
||||||
|
udf_column = udf(*all_cols)
|
||||||
|
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
|
||||||
|
return DataFrame(jdf, self.sql_ctx)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_cols(gd):
|
||||||
|
df = gd._df
|
||||||
|
return [df[col] for col in df.columns]
|
|
@ -2814,6 +2814,8 @@ class PandasUDFType(object):
|
||||||
|
|
||||||
GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
|
GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
|
||||||
|
|
||||||
|
COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
|
||||||
|
|
||||||
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
|
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
|
||||||
|
|
||||||
MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
|
MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
|
||||||
|
@ -3320,7 +3322,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
|
||||||
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
||||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||||
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
|
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
|
||||||
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF]:
|
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
|
||||||
|
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
|
||||||
raise ValueError("Invalid functionType: "
|
raise ValueError("Invalid functionType: "
|
||||||
"functionType must be one the values from PandasUDFType")
|
"functionType must be one the values from PandasUDFType")
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
|
||||||
from pyspark.sql.column import Column, _to_seq
|
from pyspark.sql.column import Column, _to_seq
|
||||||
from pyspark.sql.dataframe import DataFrame
|
from pyspark.sql.dataframe import DataFrame
|
||||||
from pyspark.sql.types import *
|
from pyspark.sql.types import *
|
||||||
|
from pyspark.sql.cogroup import CoGroupedData
|
||||||
|
|
||||||
__all__ = ["GroupedData"]
|
__all__ = ["GroupedData"]
|
||||||
|
|
||||||
|
@ -218,6 +219,15 @@ class GroupedData(object):
|
||||||
jgd = self._jgd.pivot(pivot_col, values)
|
jgd = self._jgd.pivot(pivot_col, values)
|
||||||
return GroupedData(jgd, self._df)
|
return GroupedData(jgd, self._df)
|
||||||
|
|
||||||
|
@since(3.0)
|
||||||
|
def cogroup(self, other):
|
||||||
|
"""
|
||||||
|
Cogroups this group with another group so that we can run cogrouped operations.
|
||||||
|
|
||||||
|
See :class:`CoGroupedData` for the operations that can be run.
|
||||||
|
"""
|
||||||
|
return CoGroupedData(self, other)
|
||||||
|
|
||||||
@since(2.3)
|
@since(2.3)
|
||||||
def apply(self, udf):
|
def apply(self, udf):
|
||||||
"""
|
"""
|
||||||
|
@ -232,7 +242,7 @@ class GroupedData(object):
|
||||||
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
|
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
|
||||||
returnType of the pandas udf.
|
returnType of the pandas udf.
|
||||||
|
|
||||||
.. note:: This function requires a full shuffle. all the data of a group will be loaded
|
.. note:: This function requires a full shuffle. All the data of a group will be loaded
|
||||||
into memory, so the user should be aware of the potential OOM risk if data is skewed
|
into memory, so the user should be aware of the potential OOM risk if data is skewed
|
||||||
and certain groups are too large to fit in memory.
|
and certain groups are too large to fit in memory.
|
||||||
|
|
||||||
|
|
280
python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
Normal file
280
python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
Normal file
|
@ -0,0 +1,280 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
# contributor license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright ownership.
|
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
# (the "License"); you may not use this file except in compliance with
|
||||||
|
# the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
|
||||||
|
from pyspark.sql.types import DoubleType, StructType, StructField
|
||||||
|
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
|
||||||
|
pandas_requirement_message, pyarrow_requirement_message
|
||||||
|
from pyspark.testing.utils import QuietTest
|
||||||
|
|
||||||
|
if have_pandas:
|
||||||
|
import pandas as pd
|
||||||
|
from pandas.util.testing import assert_frame_equal, assert_series_equal
|
||||||
|
|
||||||
|
if have_pyarrow:
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
|
||||||
|
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
|
||||||
|
"""
|
||||||
|
if sys.version < '3':
|
||||||
|
_check_column_type = False
|
||||||
|
else:
|
||||||
|
_check_column_type = True
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not have_pandas or not have_pyarrow,
|
||||||
|
pandas_requirement_message or pyarrow_requirement_message)
|
||||||
|
class CoGroupedMapPandasUDFTests(ReusedSQLTestCase):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data1(self):
|
||||||
|
return self.spark.range(10).toDF('id') \
|
||||||
|
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
|
||||||
|
.withColumn("k", explode(col('ks')))\
|
||||||
|
.withColumn("v", col('k') * 10)\
|
||||||
|
.drop('ks')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data2(self):
|
||||||
|
return self.spark.range(10).toDF('id') \
|
||||||
|
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
|
||||||
|
.withColumn("k", explode(col('ks'))) \
|
||||||
|
.withColumn("v2", col('k') * 100) \
|
||||||
|
.drop('ks')
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
self._test_merge(self.data1, self.data2)
|
||||||
|
|
||||||
|
def test_left_group_empty(self):
|
||||||
|
left = self.data1.where(col("id") % 2 == 0)
|
||||||
|
self._test_merge(left, self.data2)
|
||||||
|
|
||||||
|
def test_right_group_empty(self):
|
||||||
|
right = self.data2.where(col("id") % 2 == 0)
|
||||||
|
self._test_merge(self.data1, right)
|
||||||
|
|
||||||
|
def test_different_schemas(self):
|
||||||
|
right = self.data2.withColumn('v3', lit('a'))
|
||||||
|
self._test_merge(self.data1, right, 'id long, k int, v int, v2 int, v3 string')
|
||||||
|
|
||||||
|
def test_complex_group_by(self):
|
||||||
|
left = pd.DataFrame.from_dict({
|
||||||
|
'id': [1, 2, 3],
|
||||||
|
'k': [5, 6, 7],
|
||||||
|
'v': [9, 10, 11]
|
||||||
|
})
|
||||||
|
|
||||||
|
right = pd.DataFrame.from_dict({
|
||||||
|
'id': [11, 12, 13],
|
||||||
|
'k': [5, 6, 7],
|
||||||
|
'v2': [90, 100, 110]
|
||||||
|
})
|
||||||
|
|
||||||
|
left_gdf = self.spark\
|
||||||
|
.createDataFrame(left)\
|
||||||
|
.groupby(col('id') % 2 == 0)
|
||||||
|
|
||||||
|
right_gdf = self.spark \
|
||||||
|
.createDataFrame(right) \
|
||||||
|
.groupby(col('id') % 2 == 0)
|
||||||
|
|
||||||
|
@pandas_udf('k long, v long, v2 long', PandasUDFType.COGROUPED_MAP)
|
||||||
|
def merge_pandas(l, r):
|
||||||
|
return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k'])
|
||||||
|
|
||||||
|
result = left_gdf \
|
||||||
|
.cogroup(right_gdf) \
|
||||||
|
.apply(merge_pandas) \
|
||||||
|
.sort(['k']) \
|
||||||
|
.toPandas()
|
||||||
|
|
||||||
|
expected = pd.DataFrame.from_dict({
|
||||||
|
'k': [5, 6, 7],
|
||||||
|
'v': [9, 10, 11],
|
||||||
|
'v2': [90, 100, 110]
|
||||||
|
})
|
||||||
|
|
||||||
|
assert_frame_equal(expected, result, check_column_type=_check_column_type)
|
||||||
|
|
||||||
|
def test_empty_group_by(self):
|
||||||
|
left = self.data1
|
||||||
|
right = self.data2
|
||||||
|
|
||||||
|
@pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
|
||||||
|
def merge_pandas(l, r):
|
||||||
|
return pd.merge(l, r, on=['id', 'k'])
|
||||||
|
|
||||||
|
result = left.groupby().cogroup(right.groupby())\
|
||||||
|
.apply(merge_pandas) \
|
||||||
|
.sort(['id', 'k']) \
|
||||||
|
.toPandas()
|
||||||
|
|
||||||
|
left = left.toPandas()
|
||||||
|
right = right.toPandas()
|
||||||
|
|
||||||
|
expected = pd \
|
||||||
|
.merge(left, right, on=['id', 'k']) \
|
||||||
|
.sort_values(by=['id', 'k'])
|
||||||
|
|
||||||
|
assert_frame_equal(expected, result, check_column_type=_check_column_type)
|
||||||
|
|
||||||
|
def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
|
||||||
|
df = self.spark.range(0, 10).toDF('v1')
|
||||||
|
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
|
||||||
|
.withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
|
||||||
|
|
||||||
|
result = df.groupby().cogroup(df.groupby())\
|
||||||
|
.apply(pandas_udf(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
|
||||||
|
'sum1 int, sum2 int',
|
||||||
|
PandasUDFType.COGROUPED_MAP)).collect()
|
||||||
|
|
||||||
|
self.assertEquals(result[0]['sum1'], 165)
|
||||||
|
self.assertEquals(result[0]['sum2'], 165)
|
||||||
|
|
||||||
|
def test_with_key_left(self):
|
||||||
|
self._test_with_key(self.data1, self.data1, isLeft=True)
|
||||||
|
|
||||||
|
def test_with_key_right(self):
|
||||||
|
self._test_with_key(self.data1, self.data1, isLeft=False)
|
||||||
|
|
||||||
|
def test_with_key_left_group_empty(self):
|
||||||
|
left = self.data1.where(col("id") % 2 == 0)
|
||||||
|
self._test_with_key(left, self.data1, isLeft=True)
|
||||||
|
|
||||||
|
def test_with_key_right_group_empty(self):
|
||||||
|
right = self.data1.where(col("id") % 2 == 0)
|
||||||
|
self._test_with_key(self.data1, right, isLeft=False)
|
||||||
|
|
||||||
|
def test_with_key_complex(self):
|
||||||
|
|
||||||
|
@pandas_udf('id long, k int, v int, key boolean', PandasUDFType.COGROUPED_MAP)
|
||||||
|
def left_assign_key(key, l, _):
|
||||||
|
return l.assign(key=key[0])
|
||||||
|
|
||||||
|
result = self.data1 \
|
||||||
|
.groupby(col('id') % 2 == 0)\
|
||||||
|
.cogroup(self.data2.groupby(col('id') % 2 == 0)) \
|
||||||
|
.apply(left_assign_key) \
|
||||||
|
.sort(['id', 'k']) \
|
||||||
|
.toPandas()
|
||||||
|
|
||||||
|
expected = self.data1.toPandas()
|
||||||
|
expected = expected.assign(key=expected.id % 2 == 0)
|
||||||
|
|
||||||
|
assert_frame_equal(expected, result, check_column_type=_check_column_type)
|
||||||
|
|
||||||
|
def test_wrong_return_type(self):
|
||||||
|
with QuietTest(self.sc):
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
NotImplementedError,
|
||||||
|
'Invalid returnType.*cogrouped map Pandas UDF.*MapType'):
|
||||||
|
pandas_udf(
|
||||||
|
lambda l, r: l,
|
||||||
|
'id long, v map<int, int>',
|
||||||
|
PandasUDFType.COGROUPED_MAP)
|
||||||
|
|
||||||
|
def test_wrong_args(self):
|
||||||
|
# Test that we get a sensible exception invalid values passed to apply
|
||||||
|
left = self.data1
|
||||||
|
right = self.data2
|
||||||
|
with QuietTest(self.sc):
|
||||||
|
# Function rather than a udf
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||||
|
left.groupby('id').cogroup(right.groupby('id')).apply(lambda l, r: l)
|
||||||
|
|
||||||
|
# Udf missing return type
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||||
|
left.groupby('id').cogroup(right.groupby('id'))\
|
||||||
|
.apply(udf(lambda l, r: l, DoubleType()))
|
||||||
|
|
||||||
|
# Pass in expression rather than udf
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||||
|
left.groupby('id').cogroup(right.groupby('id')).apply(left.v + 1)
|
||||||
|
|
||||||
|
# Zero arg function
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
|
||||||
|
left.groupby('id').cogroup(right.groupby('id'))\
|
||||||
|
.apply(pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
|
||||||
|
|
||||||
|
# Udf without PandasUDFType
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||||
|
left.groupby('id').cogroup(right.groupby('id'))\
|
||||||
|
.apply(pandas_udf(lambda x, y: x, DoubleType()))
|
||||||
|
|
||||||
|
# Udf with incorrect PandasUDFType
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*COGROUPED_MAP'):
|
||||||
|
left.groupby('id').cogroup(right.groupby('id'))\
|
||||||
|
.apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _test_with_key(left, right, isLeft):
|
||||||
|
|
||||||
|
@pandas_udf('id long, k int, v int, key long', PandasUDFType.COGROUPED_MAP)
|
||||||
|
def right_assign_key(key, l, r):
|
||||||
|
return l.assign(key=key[0]) if isLeft else r.assign(key=key[0])
|
||||||
|
|
||||||
|
result = left \
|
||||||
|
.groupby('id') \
|
||||||
|
.cogroup(right.groupby('id')) \
|
||||||
|
.apply(right_assign_key) \
|
||||||
|
.toPandas()
|
||||||
|
|
||||||
|
expected = left.toPandas() if isLeft else right.toPandas()
|
||||||
|
expected = expected.assign(key=expected.id)
|
||||||
|
|
||||||
|
assert_frame_equal(expected, result, check_column_type=_check_column_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'):
|
||||||
|
|
||||||
|
@pandas_udf(output_schema, PandasUDFType.COGROUPED_MAP)
|
||||||
|
def merge_pandas(l, r):
|
||||||
|
return pd.merge(l, r, on=['id', 'k'])
|
||||||
|
|
||||||
|
result = left \
|
||||||
|
.groupby('id') \
|
||||||
|
.cogroup(right.groupby('id')) \
|
||||||
|
.apply(merge_pandas)\
|
||||||
|
.sort(['id', 'k']) \
|
||||||
|
.toPandas()
|
||||||
|
|
||||||
|
left = left.toPandas()
|
||||||
|
right = right.toPandas()
|
||||||
|
|
||||||
|
expected = pd \
|
||||||
|
.merge(left, right, on=['id', 'k']) \
|
||||||
|
.sort_values(by=['id', 'k'])
|
||||||
|
|
||||||
|
assert_frame_equal(expected, result, check_column_type=_check_column_type)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from pyspark.sql.tests.test_pandas_udf_cogrouped_map import *
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xmlrunner
|
||||||
|
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
|
||||||
|
except ImportError:
|
||||||
|
testRunner = None
|
||||||
|
unittest.main(testRunner=testRunner, verbosity=2)
|
|
@ -42,6 +42,7 @@ def _create_udf(f, returnType, evalType):
|
||||||
if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
||||||
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
||||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||||
|
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
|
||||||
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
|
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
|
||||||
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
|
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
|
||||||
|
|
||||||
|
@ -65,6 +66,13 @@ def _create_udf(f, returnType, evalType):
|
||||||
"Invalid function: pandas_udfs with function type GROUPED_MAP "
|
"Invalid function: pandas_udfs with function type GROUPED_MAP "
|
||||||
"must take either one argument (data) or two arguments (key, data).")
|
"must take either one argument (data) or two arguments (key, data).")
|
||||||
|
|
||||||
|
if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF \
|
||||||
|
and len(argspec.args) not in (2, 3):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid function: pandas_udfs with function type COGROUPED_MAP "
|
||||||
|
"must take either two arguments (left, right) "
|
||||||
|
"or three arguments (key, left, right).")
|
||||||
|
|
||||||
# Set the name of the UserDefinedFunction object to be the name of function f
|
# Set the name of the UserDefinedFunction object to be the name of function f
|
||||||
udf_obj = UserDefinedFunction(
|
udf_obj = UserDefinedFunction(
|
||||||
f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
|
f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
|
||||||
|
@ -147,6 +155,17 @@ class UserDefinedFunction(object):
|
||||||
else:
|
else:
|
||||||
raise TypeError("Invalid returnType for map iterator Pandas "
|
raise TypeError("Invalid returnType for map iterator Pandas "
|
||||||
"UDFs: returnType must be a StructType.")
|
"UDFs: returnType must be a StructType.")
|
||||||
|
elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
|
||||||
|
if isinstance(self._returnType_placeholder, StructType):
|
||||||
|
try:
|
||||||
|
to_arrow_type(self._returnType_placeholder)
|
||||||
|
except TypeError:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Invalid returnType with cogrouped map Pandas UDFs: "
|
||||||
|
"%s is not supported" % str(self._returnType_placeholder))
|
||||||
|
else:
|
||||||
|
raise TypeError("Invalid returnType for cogrouped map Pandas "
|
||||||
|
"UDFs: returnType must be a StructType.")
|
||||||
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
|
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
|
||||||
try:
|
try:
|
||||||
# StructType is not yet allowed as a return type, explicitly check here to fail fast
|
# StructType is not yet allowed as a return type, explicitly check here to fail fast
|
||||||
|
|
|
@ -39,7 +39,7 @@ from pyspark.resourceinformation import ResourceInformation
|
||||||
from pyspark.rdd import PythonEvalType
|
from pyspark.rdd import PythonEvalType
|
||||||
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
|
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
|
||||||
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
|
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
|
||||||
BatchedSerializer, ArrowStreamPandasUDFSerializer
|
BatchedSerializer, ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
|
||||||
from pyspark.sql.types import to_arrow_type, StructType
|
from pyspark.sql.types import to_arrow_type, StructType
|
||||||
from pyspark.util import _get_argspec, fail_on_stopiteration
|
from pyspark.util import _get_argspec, fail_on_stopiteration
|
||||||
from pyspark import shuffle
|
from pyspark import shuffle
|
||||||
|
@ -121,6 +121,33 @@ def wrap_pandas_iter_udf(f, return_type):
|
||||||
map(verify_result_type, f(*iterator)))
|
map(verify_result_type, f(*iterator)))
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
|
||||||
|
|
||||||
|
def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
left_df = pd.concat(left_value_series, axis=1)
|
||||||
|
right_df = pd.concat(right_value_series, axis=1)
|
||||||
|
|
||||||
|
if len(argspec.args) == 2:
|
||||||
|
result = f(left_df, right_df)
|
||||||
|
elif len(argspec.args) == 3:
|
||||||
|
key_series = left_key_series if not left_df.empty else right_key_series
|
||||||
|
key = tuple(s[0] for s in key_series)
|
||||||
|
result = f(key, left_df, right_df)
|
||||||
|
if not isinstance(result, pd.DataFrame):
|
||||||
|
raise TypeError("Return type of the user-defined function should be "
|
||||||
|
"pandas.DataFrame, but is {}".format(type(result)))
|
||||||
|
if not len(result.columns) == len(return_type):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Number of columns of the returned pandas.DataFrame "
|
||||||
|
"doesn't match specified schema. "
|
||||||
|
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
|
||||||
|
return result
|
||||||
|
|
||||||
|
return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))]
|
||||||
|
|
||||||
|
|
||||||
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
|
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
|
||||||
|
|
||||||
def wrapped(key_series, value_series):
|
def wrapped(key_series, value_series):
|
||||||
|
@ -244,6 +271,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
|
||||||
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
|
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
|
||||||
argspec = _get_argspec(chained_func) # signature was lost when wrapping it
|
argspec = _get_argspec(chained_func) # signature was lost when wrapping it
|
||||||
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
|
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
|
||||||
|
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
|
||||||
|
argspec = _get_argspec(chained_func) # signature was lost when wrapping it
|
||||||
|
return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
|
||||||
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
|
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
|
||||||
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
|
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
|
||||||
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
|
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
|
||||||
|
@ -258,6 +288,7 @@ def read_udfs(pickleSer, infile, eval_type):
|
||||||
runner_conf = {}
|
runner_conf = {}
|
||||||
|
|
||||||
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
||||||
|
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
|
||||||
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
||||||
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
|
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
|
||||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||||
|
@ -280,6 +311,9 @@ def read_udfs(pickleSer, infile, eval_type):
|
||||||
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
|
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
|
||||||
.lower() == "true"
|
.lower() == "true"
|
||||||
|
|
||||||
|
if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
|
||||||
|
ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
|
||||||
|
else:
|
||||||
# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
|
# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
|
||||||
# pandas Series. See SPARK-27240.
|
# pandas Series. See SPARK-27240.
|
||||||
df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
|
df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
|
||||||
|
@ -343,6 +377,32 @@ def read_udfs(pickleSer, infile, eval_type):
|
||||||
# profiling is not supported for UDF
|
# profiling is not supported for UDF
|
||||||
return func, None, ser, ser
|
return func, None, ser, ser
|
||||||
|
|
||||||
|
def extract_key_value_indexes(grouped_arg_offsets):
|
||||||
|
"""
|
||||||
|
Helper function to extract the key and value indexes from arg_offsets for the grouped and
|
||||||
|
cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.
|
||||||
|
|
||||||
|
:param grouped_arg_offsets: List containing the key and value indexes of columns of the
|
||||||
|
DataFrames to be passed to the udf. It consists of n repeating groups where n is the
|
||||||
|
number of DataFrames. Each group has the following format:
|
||||||
|
group[0]: length of group
|
||||||
|
group[1]: length of key indexes
|
||||||
|
group[2.. group[1] +2]: key attributes
|
||||||
|
group[group[1] +3 group[0]]: value attributes
|
||||||
|
"""
|
||||||
|
parsed = []
|
||||||
|
idx = 0
|
||||||
|
while idx < len(grouped_arg_offsets):
|
||||||
|
offsets_len = grouped_arg_offsets[idx]
|
||||||
|
idx += 1
|
||||||
|
offsets = grouped_arg_offsets[idx: idx + offsets_len]
|
||||||
|
split_index = offsets[0] + 1
|
||||||
|
offset_keys = offsets[1: split_index]
|
||||||
|
offset_values = offsets[split_index:]
|
||||||
|
parsed.append([offset_keys, offset_values])
|
||||||
|
idx += offsets_len
|
||||||
|
return parsed
|
||||||
|
|
||||||
udfs = {}
|
udfs = {}
|
||||||
call_udf = []
|
call_udf = []
|
||||||
mapper_str = ""
|
mapper_str = ""
|
||||||
|
@ -359,10 +419,24 @@ def read_udfs(pickleSer, infile, eval_type):
|
||||||
arg_offsets, udf = read_single_udf(
|
arg_offsets, udf = read_single_udf(
|
||||||
pickleSer, infile, eval_type, runner_conf, udf_index=0)
|
pickleSer, infile, eval_type, runner_conf, udf_index=0)
|
||||||
udfs['f'] = udf
|
udfs['f'] = udf
|
||||||
split_offset = arg_offsets[0] + 1
|
parsed_offsets = extract_key_value_indexes(arg_offsets)
|
||||||
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
|
keys = ["a[%d]" % (o,) for o in parsed_offsets[0][0]]
|
||||||
arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
|
vals = ["a[%d]" % (o, ) for o in parsed_offsets[0][1]]
|
||||||
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
|
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(keys), ", ".join(vals))
|
||||||
|
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
|
||||||
|
# We assume there is only one UDF here because cogrouped map doesn't
|
||||||
|
# support combining multiple UDFs.
|
||||||
|
assert num_udfs == 1
|
||||||
|
arg_offsets, udf = read_single_udf(
|
||||||
|
pickleSer, infile, eval_type, runner_conf, udf_index=0)
|
||||||
|
udfs['f'] = udf
|
||||||
|
parsed_offsets = extract_key_value_indexes(arg_offsets)
|
||||||
|
df1_keys = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][0]]
|
||||||
|
df1_vals = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][1]]
|
||||||
|
df2_keys = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][0]]
|
||||||
|
df2_vals = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][1]]
|
||||||
|
mapper_str = "lambda a: f([%s], [%s], [%s], [%s])" % (
|
||||||
|
", ".join(df1_keys), ", ".join(df1_vals), ", ".join(df2_keys), ", ".join(df2_vals))
|
||||||
else:
|
else:
|
||||||
# Create function like this:
|
# Create function like this:
|
||||||
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
|
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
|
||||||
|
|
|
@ -1191,6 +1191,12 @@ class Analyzer(
|
||||||
// To resolve duplicate expression IDs for Join and Intersect
|
// To resolve duplicate expression IDs for Join and Intersect
|
||||||
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
|
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
|
||||||
j.copy(right = dedupRight(left, right))
|
j.copy(right = dedupRight(left, right))
|
||||||
|
case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) =>
|
||||||
|
val leftRes = leftAttributes
|
||||||
|
.map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute])
|
||||||
|
val rightRes = rightAttributes
|
||||||
|
.map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute])
|
||||||
|
f.copy(leftAttributes = leftRes, rightAttributes = rightRes)
|
||||||
// intersect/except will be rewritten to join at the begininng of optimizer. Here we need to
|
// intersect/except will be rewritten to join at the begininng of optimizer. Here we need to
|
||||||
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
|
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
|
||||||
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
|
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
|
||||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame.
|
* FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
|
||||||
* This is used by DataFrame.groupby().apply().
|
* This is used by DataFrame.groupby().apply().
|
||||||
*/
|
*/
|
||||||
case class FlatMapGroupsInPandas(
|
case class FlatMapGroupsInPandas(
|
||||||
|
@ -40,7 +40,7 @@ case class FlatMapGroupsInPandas(
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map partitions using an udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame).
|
* Map partitions using a udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame).
|
||||||
* This is used by DataFrame.mapInPandas()
|
* This is used by DataFrame.mapInPandas()
|
||||||
*/
|
*/
|
||||||
case class MapInPandas(
|
case class MapInPandas(
|
||||||
|
@ -51,6 +51,21 @@ case class MapInPandas(
|
||||||
override val producedAttributes = AttributeSet(output)
|
override val producedAttributes = AttributeSet(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe
|
||||||
|
* This is used by DataFrame.groupby().cogroup().apply().
|
||||||
|
*/
|
||||||
|
case class FlatMapCoGroupsInPandas(
|
||||||
|
leftAttributes: Seq[Attribute],
|
||||||
|
rightAttributes: Seq[Attribute],
|
||||||
|
functionExpr: Expression,
|
||||||
|
output: Seq[Attribute],
|
||||||
|
left: LogicalPlan,
|
||||||
|
right: LogicalPlan) extends BinaryNode {
|
||||||
|
|
||||||
|
override val producedAttributes = AttributeSet(output)
|
||||||
|
}
|
||||||
|
|
||||||
trait BaseEvalPython extends UnaryNode {
|
trait BaseEvalPython extends UnaryNode {
|
||||||
|
|
||||||
def udfs: Seq[PythonUDF]
|
def udfs: Seq[PythonUDF]
|
||||||
|
|
|
@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
|
||||||
*/
|
*/
|
||||||
@Stable
|
@Stable
|
||||||
class RelationalGroupedDataset protected[sql](
|
class RelationalGroupedDataset protected[sql](
|
||||||
df: DataFrame,
|
val df: DataFrame,
|
||||||
groupingExprs: Seq[Expression],
|
val groupingExprs: Seq[Expression],
|
||||||
groupType: RelationalGroupedDataset.GroupType) {
|
groupType: RelationalGroupedDataset.GroupType) {
|
||||||
|
|
||||||
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
|
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
|
||||||
|
@ -523,6 +523,48 @@ class RelationalGroupedDataset protected[sql](
|
||||||
Dataset.ofRows(df.sparkSession, plan)
|
Dataset.ofRows(df.sparkSession, plan)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies a vectorized python user-defined function to each cogrouped data.
|
||||||
|
* The user-defined function defines a transformation:
|
||||||
|
* `pandas.DataFrame`, `pandas.DataFrame` -> `pandas.DataFrame`.
|
||||||
|
* For each group in the cogrouped data, all elements in the group are passed as a
|
||||||
|
* `pandas.DataFrame` and the results for all cogroups are combined into a new [[DataFrame]].
|
||||||
|
*
|
||||||
|
* This function uses Apache Arrow as serialization format between Java executors and Python
|
||||||
|
* workers.
|
||||||
|
*/
|
||||||
|
private[sql] def flatMapCoGroupsInPandas(
|
||||||
|
r: RelationalGroupedDataset,
|
||||||
|
expr: PythonUDF): DataFrame = {
|
||||||
|
require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
|
||||||
|
"Must pass a cogrouped map udf")
|
||||||
|
require(expr.dataType.isInstanceOf[StructType],
|
||||||
|
s"The returnType of the udf must be a ${StructType.simpleString}")
|
||||||
|
|
||||||
|
val leftGroupingNamedExpressions = groupingExprs.map {
|
||||||
|
case ne: NamedExpression => ne
|
||||||
|
case other => Alias(other, other.toString)()
|
||||||
|
}
|
||||||
|
|
||||||
|
val rightGroupingNamedExpressions = r.groupingExprs.map {
|
||||||
|
case ne: NamedExpression => ne
|
||||||
|
case other => Alias(other, other.toString)()
|
||||||
|
}
|
||||||
|
|
||||||
|
val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
|
||||||
|
val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)
|
||||||
|
|
||||||
|
val leftChild = df.logicalPlan
|
||||||
|
val rightChild = r.df.logicalPlan
|
||||||
|
|
||||||
|
val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
|
||||||
|
val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)
|
||||||
|
|
||||||
|
val output = expr.dataType.asInstanceOf[StructType].toAttributes
|
||||||
|
val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)
|
||||||
|
Dataset.ofRows(df.sparkSession, plan)
|
||||||
|
}
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
val builder = new StringBuilder
|
val builder = new StringBuilder
|
||||||
builder.append("RelationalGroupedDataset: [grouping expressions: [")
|
builder.append("RelationalGroupedDataset: [grouping expressions: [")
|
||||||
|
|
|
@ -682,6 +682,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
||||||
f, p, b, is, ot, planLater(child)) :: Nil
|
f, p, b, is, ot, planLater(child)) :: Nil
|
||||||
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
|
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
|
||||||
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
|
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
|
||||||
|
case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) =>
|
||||||
|
execution.python.FlatMapCoGroupsInPandasExec(
|
||||||
|
leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil
|
||||||
case logical.MapInPandas(func, output, child) =>
|
case logical.MapInPandas(func, output, child) =>
|
||||||
execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
|
execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
|
||||||
case logical.MapElements(f, _, _, objAttr, child) =>
|
case logical.MapElements(f, _, _, objAttr, child) =>
|
||||||
|
|
|
@ -19,12 +19,9 @@ package org.apache.spark.sql.execution.python
|
||||||
|
|
||||||
import java.io._
|
import java.io._
|
||||||
import java.net._
|
import java.net._
|
||||||
import java.util.concurrent.atomic.AtomicBoolean
|
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
|
||||||
|
|
||||||
import org.apache.arrow.vector.VectorSchemaRoot
|
import org.apache.arrow.vector.VectorSchemaRoot
|
||||||
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
|
import org.apache.arrow.vector.ipc.ArrowStreamWriter
|
||||||
|
|
||||||
import org.apache.spark._
|
import org.apache.spark._
|
||||||
import org.apache.spark.api.python._
|
import org.apache.spark.api.python._
|
||||||
|
@ -33,7 +30,6 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.sql.util.ArrowUtils
|
import org.apache.spark.sql.util.ArrowUtils
|
||||||
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
|
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -46,7 +42,7 @@ class ArrowPythonRunner(
|
||||||
schema: StructType,
|
schema: StructType,
|
||||||
timeZoneId: String,
|
timeZoneId: String,
|
||||||
conf: Map[String, String])
|
conf: Map[String, String])
|
||||||
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
|
extends BaseArrowPythonRunner[Iterator[InternalRow]](
|
||||||
funcs, evalType, argOffsets) {
|
funcs, evalType, argOffsets) {
|
||||||
|
|
||||||
override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
|
override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
|
||||||
|
@ -119,72 +115,4 @@ class ArrowPythonRunner(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override def newReaderIterator(
|
|
||||||
stream: DataInputStream,
|
|
||||||
writerThread: WriterThread,
|
|
||||||
startTime: Long,
|
|
||||||
env: SparkEnv,
|
|
||||||
worker: Socket,
|
|
||||||
releasedOrClosed: AtomicBoolean,
|
|
||||||
context: TaskContext): Iterator[ColumnarBatch] = {
|
|
||||||
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
|
||||||
|
|
||||||
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
|
|
||||||
s"stdin reader for $pythonExec", 0, Long.MaxValue)
|
|
||||||
|
|
||||||
private var reader: ArrowStreamReader = _
|
|
||||||
private var root: VectorSchemaRoot = _
|
|
||||||
private var schema: StructType = _
|
|
||||||
private var vectors: Array[ColumnVector] = _
|
|
||||||
|
|
||||||
context.addTaskCompletionListener[Unit] { _ =>
|
|
||||||
if (reader != null) {
|
|
||||||
reader.close(false)
|
|
||||||
}
|
|
||||||
allocator.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
private var batchLoaded = true
|
|
||||||
|
|
||||||
protected override def read(): ColumnarBatch = {
|
|
||||||
if (writerThread.exception.isDefined) {
|
|
||||||
throw writerThread.exception.get
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
if (reader != null && batchLoaded) {
|
|
||||||
batchLoaded = reader.loadNextBatch()
|
|
||||||
if (batchLoaded) {
|
|
||||||
val batch = new ColumnarBatch(vectors)
|
|
||||||
batch.setNumRows(root.getRowCount)
|
|
||||||
batch
|
|
||||||
} else {
|
|
||||||
reader.close(false)
|
|
||||||
allocator.close()
|
|
||||||
// Reach end of stream. Call `read()` again to read control data.
|
|
||||||
read()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
stream.readInt() match {
|
|
||||||
case SpecialLengths.START_ARROW_STREAM =>
|
|
||||||
reader = new ArrowStreamReader(stream, allocator)
|
|
||||||
root = reader.getVectorSchemaRoot()
|
|
||||||
schema = ArrowUtils.fromArrowSchema(root.getSchema())
|
|
||||||
vectors = root.getFieldVectors().asScala.map { vector =>
|
|
||||||
new ArrowColumnVector(vector)
|
|
||||||
}.toArray[ColumnVector]
|
|
||||||
read()
|
|
||||||
case SpecialLengths.TIMING_DATA =>
|
|
||||||
handleTimingData()
|
|
||||||
read()
|
|
||||||
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
|
|
||||||
throw handlePythonException()
|
|
||||||
case SpecialLengths.END_OF_DATA_SECTION =>
|
|
||||||
handleEndOfDataSection()
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch handleException
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.spark.sql.execution.python
|
||||||
|
|
||||||
|
import java.io._
|
||||||
|
import java.net._
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean
|
||||||
|
|
||||||
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
|
import org.apache.arrow.vector.VectorSchemaRoot
|
||||||
|
import org.apache.arrow.vector.ipc.ArrowStreamReader
|
||||||
|
|
||||||
|
import org.apache.spark._
|
||||||
|
import org.apache.spark.api.python._
|
||||||
|
import org.apache.spark.sql.types.StructType
|
||||||
|
import org.apache.spark.sql.util.ArrowUtils
|
||||||
|
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Common functionality for a udf runner that exchanges data with Python worker via Arrow stream.
|
||||||
|
*/
|
||||||
|
abstract class BaseArrowPythonRunner[T](
|
||||||
|
funcs: Seq[ChainedPythonFunctions],
|
||||||
|
evalType: Int,
|
||||||
|
argOffsets: Array[Array[Int]])
|
||||||
|
extends BasePythonRunner[T, ColumnarBatch](funcs, evalType, argOffsets) {
|
||||||
|
|
||||||
|
protected override def newReaderIterator(
|
||||||
|
stream: DataInputStream,
|
||||||
|
writerThread: WriterThread,
|
||||||
|
startTime: Long,
|
||||||
|
env: SparkEnv,
|
||||||
|
worker: Socket,
|
||||||
|
releasedOrClosed: AtomicBoolean,
|
||||||
|
context: TaskContext): Iterator[ColumnarBatch] = {
|
||||||
|
|
||||||
|
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
||||||
|
|
||||||
|
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
|
||||||
|
s"stdin reader for $pythonExec", 0, Long.MaxValue)
|
||||||
|
|
||||||
|
private var reader: ArrowStreamReader = _
|
||||||
|
private var root: VectorSchemaRoot = _
|
||||||
|
private var schema: StructType = _
|
||||||
|
private var vectors: Array[ColumnVector] = _
|
||||||
|
|
||||||
|
context.addTaskCompletionListener[Unit] { _ =>
|
||||||
|
if (reader != null) {
|
||||||
|
reader.close(false)
|
||||||
|
}
|
||||||
|
allocator.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
private var batchLoaded = true
|
||||||
|
|
||||||
|
protected override def read(): ColumnarBatch = {
|
||||||
|
if (writerThread.exception.isDefined) {
|
||||||
|
throw writerThread.exception.get
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
if (reader != null && batchLoaded) {
|
||||||
|
batchLoaded = reader.loadNextBatch()
|
||||||
|
if (batchLoaded) {
|
||||||
|
val batch = new ColumnarBatch(vectors)
|
||||||
|
batch.setNumRows(root.getRowCount)
|
||||||
|
batch
|
||||||
|
} else {
|
||||||
|
reader.close(false)
|
||||||
|
allocator.close()
|
||||||
|
// Reach end of stream. Call `read()` again to read control data.
|
||||||
|
read()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
stream.readInt() match {
|
||||||
|
case SpecialLengths.START_ARROW_STREAM =>
|
||||||
|
reader = new ArrowStreamReader(stream, allocator)
|
||||||
|
root = reader.getVectorSchemaRoot()
|
||||||
|
schema = ArrowUtils.fromArrowSchema(root.getSchema())
|
||||||
|
vectors = root.getFieldVectors().asScala.map { vector =>
|
||||||
|
new ArrowColumnVector(vector)
|
||||||
|
}.toArray[ColumnVector]
|
||||||
|
read()
|
||||||
|
case SpecialLengths.TIMING_DATA =>
|
||||||
|
handleTimingData()
|
||||||
|
read()
|
||||||
|
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
|
||||||
|
throw handlePythonException()
|
||||||
|
case SpecialLengths.END_OF_DATA_SECTION =>
|
||||||
|
handleEndOfDataSection()
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch handleException
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,137 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.execution.python
|
||||||
|
|
||||||
|
import scala.collection.JavaConverters._
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
|
import org.apache.spark.TaskContext
|
||||||
|
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
|
||||||
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, UnsafeProjection}
|
||||||
|
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
|
||||||
|
import org.apache.spark.sql.util.ArrowUtils
|
||||||
|
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base functionality for plans which execute grouped python udfs.
|
||||||
|
*/
|
||||||
|
abstract class BasePandasGroupExec(
|
||||||
|
func: Expression,
|
||||||
|
output: Seq[Attribute])
|
||||||
|
extends SparkPlan {
|
||||||
|
|
||||||
|
protected val sessionLocalTimeZone = conf.sessionLocalTimeZone
|
||||||
|
|
||||||
|
protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
|
||||||
|
|
||||||
|
protected val pandasFunction = func.asInstanceOf[PythonUDF].func
|
||||||
|
|
||||||
|
protected val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
|
||||||
|
|
||||||
|
override def producedAttributes: AttributeSet = AttributeSet(output)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* passes the data to the python runner and coverts the resulting
|
||||||
|
* columnarbatch into internal rows.
|
||||||
|
*/
|
||||||
|
protected def executePython[T](
|
||||||
|
data: Iterator[T],
|
||||||
|
runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = {
|
||||||
|
|
||||||
|
val context = TaskContext.get()
|
||||||
|
val columnarBatchIter = runner.compute(data, context.partitionId(), context)
|
||||||
|
val unsafeProj = UnsafeProjection.create(output, output)
|
||||||
|
|
||||||
|
columnarBatchIter.flatMap { batch =>
|
||||||
|
// UDF returns a StructType column in ColumnarBatch, select the children here
|
||||||
|
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
|
||||||
|
val outputVectors = output.indices.map(structVector.getChild)
|
||||||
|
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
|
||||||
|
flattenedBatch.setNumRows(batch.numRows())
|
||||||
|
flattenedBatch.rowIterator.asScala
|
||||||
|
}.map(unsafeProj)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* groups according to grouping attributes and then projects into the deduplicated schema
|
||||||
|
*/
|
||||||
|
protected def groupAndProject(
|
||||||
|
input: Iterator[InternalRow],
|
||||||
|
groupingAttributes: Seq[Attribute],
|
||||||
|
inputSchema: Seq[Attribute],
|
||||||
|
dedupSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
|
||||||
|
val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema)
|
||||||
|
val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema)
|
||||||
|
groupedIter.map {
|
||||||
|
case (k, groupedRowIter) => (k, groupedRowIter.map(dedupProj))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a the deduplicated attributes of the spark plan and the arg offsets of the
|
||||||
|
* keys and values.
|
||||||
|
*
|
||||||
|
* The deduplicated attributes are needed because the spark plan may contain an attribute
|
||||||
|
* twice; once in the key and once in the value. For any such attribute we need to
|
||||||
|
* deduplicate.
|
||||||
|
*
|
||||||
|
* The arg offsets are used to distinguish grouping grouping attributes and data attributes
|
||||||
|
* as following:
|
||||||
|
*
|
||||||
|
* argOffsets[0] is the length of the argOffsets array
|
||||||
|
*
|
||||||
|
* argOffsets[1] is the length of grouping attribute
|
||||||
|
* argOffsets[2 .. argOffsets[0]+2] is the arg offsets for grouping attributes
|
||||||
|
*
|
||||||
|
* argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes
|
||||||
|
*/
|
||||||
|
protected def resolveArgOffsets(
|
||||||
|
child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = {
|
||||||
|
|
||||||
|
val dataAttributes = child.output.drop(groupingAttributes.length)
|
||||||
|
val groupingIndicesInData = groupingAttributes.map { attribute =>
|
||||||
|
dataAttributes.indexWhere(attribute.semanticEquals)
|
||||||
|
}
|
||||||
|
|
||||||
|
val groupingArgOffsets = new ArrayBuffer[Int]
|
||||||
|
val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
|
||||||
|
val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
|
||||||
|
|
||||||
|
groupingAttributes.zip(groupingIndicesInData).foreach {
|
||||||
|
case (attribute, index) =>
|
||||||
|
if (index == -1) {
|
||||||
|
groupingArgOffsets += nonDupGroupingAttributes.length
|
||||||
|
nonDupGroupingAttributes += attribute
|
||||||
|
} else {
|
||||||
|
groupingArgOffsets += index + nonDupGroupingSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val dataArgOffsets = nonDupGroupingAttributes.length until
|
||||||
|
(nonDupGroupingAttributes.length + dataAttributes.length)
|
||||||
|
|
||||||
|
val argOffsetsLength = groupingAttributes.length + dataArgOffsets.length + 1
|
||||||
|
val argOffsets = Array(argOffsetsLength,
|
||||||
|
groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets
|
||||||
|
|
||||||
|
// Attributes after deduplication
|
||||||
|
val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
|
||||||
|
(dedupAttributes, argOffsets)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,113 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.execution.python
|
||||||
|
|
||||||
|
import java.io._
|
||||||
|
import java.net._
|
||||||
|
|
||||||
|
import org.apache.arrow.vector.VectorSchemaRoot
|
||||||
|
import org.apache.arrow.vector.ipc.ArrowStreamWriter
|
||||||
|
|
||||||
|
import org.apache.spark._
|
||||||
|
import org.apache.spark.api.python._
|
||||||
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
|
import org.apache.spark.sql.execution.arrow.ArrowWriter
|
||||||
|
import org.apache.spark.sql.types._
|
||||||
|
import org.apache.spark.sql.util.ArrowUtils
|
||||||
|
import org.apache.spark.util.Utils
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Python UDF Runner for cogrouped udfs. Although the data is exchanged with the python
|
||||||
|
* worker via arrow, we cannot use `ArrowPythonRunner` as we need to send more than one
|
||||||
|
* dataframe.
|
||||||
|
*/
|
||||||
|
class CogroupedArrowPythonRunner(
|
||||||
|
funcs: Seq[ChainedPythonFunctions],
|
||||||
|
evalType: Int,
|
||||||
|
argOffsets: Array[Array[Int]],
|
||||||
|
leftSchema: StructType,
|
||||||
|
rightSchema: StructType,
|
||||||
|
timeZoneId: String,
|
||||||
|
conf: Map[String, String])
|
||||||
|
extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])](
|
||||||
|
funcs, evalType, argOffsets) {
|
||||||
|
|
||||||
|
protected def newWriterThread(
|
||||||
|
env: SparkEnv,
|
||||||
|
worker: Socket,
|
||||||
|
inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])],
|
||||||
|
partitionIndex: Int,
|
||||||
|
context: TaskContext): WriterThread = {
|
||||||
|
|
||||||
|
new WriterThread(env, worker, inputIterator, partitionIndex, context) {
|
||||||
|
|
||||||
|
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
|
||||||
|
|
||||||
|
// Write config for the worker as a number of key -> value pairs of strings
|
||||||
|
dataOut.writeInt(conf.size)
|
||||||
|
for ((k, v) <- conf) {
|
||||||
|
PythonRDD.writeUTF(k, dataOut)
|
||||||
|
PythonRDD.writeUTF(v, dataOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
|
||||||
|
// For each we first send the number of dataframes in each group then send
|
||||||
|
// first df, then send second df. End of data is marked by sending 0.
|
||||||
|
while (inputIterator.hasNext) {
|
||||||
|
dataOut.writeInt(2)
|
||||||
|
val (nextLeft, nextRight) = inputIterator.next()
|
||||||
|
writeGroup(nextLeft, leftSchema, dataOut, "left")
|
||||||
|
writeGroup(nextRight, rightSchema, dataOut, "right")
|
||||||
|
}
|
||||||
|
dataOut.writeInt(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
def writeGroup(
|
||||||
|
group: Iterator[InternalRow],
|
||||||
|
schema: StructType,
|
||||||
|
dataOut: DataOutputStream,
|
||||||
|
name: String) = {
|
||||||
|
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
|
||||||
|
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
|
||||||
|
s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
|
||||||
|
val root = VectorSchemaRoot.create(arrowSchema, allocator)
|
||||||
|
|
||||||
|
Utils.tryWithSafeFinally {
|
||||||
|
val writer = new ArrowStreamWriter(root, null, dataOut)
|
||||||
|
val arrowWriter = ArrowWriter.create(root)
|
||||||
|
writer.start()
|
||||||
|
|
||||||
|
while (group.hasNext) {
|
||||||
|
arrowWriter.write(group.next())
|
||||||
|
}
|
||||||
|
arrowWriter.finish()
|
||||||
|
writer.writeBatch()
|
||||||
|
writer.end()
|
||||||
|
}{
|
||||||
|
root.close()
|
||||||
|
allocator.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.execution.python
|
||||||
|
|
||||||
|
import org.apache.spark.api.python.PythonEvalType
|
||||||
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
|
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
|
||||||
|
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
|
||||||
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInPandas]]
|
||||||
|
*
|
||||||
|
* The input dataframes are first Cogrouped. Rows from each side of the cogroup are passed to the
|
||||||
|
* Python worker via Arrow. As each side of the cogroup may have a different schema we send every
|
||||||
|
* group in its own Arrow stream.
|
||||||
|
* The Python worker turns the resulting record batches to `pandas.DataFrame`s, invokes the
|
||||||
|
* user-defined function, and passes the resulting `pandas.DataFrame`
|
||||||
|
* as an Arrow record batch. Finally, each record batch is turned to
|
||||||
|
* Iterator[InternalRow] using ColumnarBatch.
|
||||||
|
*
|
||||||
|
* Note on memory usage:
|
||||||
|
* Both the Python worker and the Java executor need to have enough memory to
|
||||||
|
* hold the largest cogroup. The memory on the Java side is used to construct the
|
||||||
|
* record batches (off heap memory). The memory on the Python side is used for
|
||||||
|
* holding the `pandas.DataFrame`. It's possible to further split one group into
|
||||||
|
* multiple record batches to reduce the memory footprint on the Java side, this
|
||||||
|
* is left as future work.
|
||||||
|
*/
|
||||||
|
case class FlatMapCoGroupsInPandasExec(
|
||||||
|
leftGroup: Seq[Attribute],
|
||||||
|
rightGroup: Seq[Attribute],
|
||||||
|
func: Expression,
|
||||||
|
output: Seq[Attribute],
|
||||||
|
left: SparkPlan,
|
||||||
|
right: SparkPlan)
|
||||||
|
extends BasePandasGroupExec(func, output) with BinaryExecNode {
|
||||||
|
|
||||||
|
override def outputPartitioning: Partitioning = left.outputPartitioning
|
||||||
|
|
||||||
|
override def requiredChildDistribution: Seq[Distribution] = {
|
||||||
|
val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
|
||||||
|
val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
|
||||||
|
leftDist :: rightDist :: Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
|
||||||
|
leftGroup
|
||||||
|
.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
override protected def doExecute(): RDD[InternalRow] = {
|
||||||
|
|
||||||
|
val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
|
||||||
|
val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)
|
||||||
|
|
||||||
|
// Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty
|
||||||
|
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
|
||||||
|
if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else {
|
||||||
|
|
||||||
|
val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup)
|
||||||
|
val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup)
|
||||||
|
val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
|
||||||
|
.map { case (_, l, r) => (l, r) }
|
||||||
|
|
||||||
|
val runner = new CogroupedArrowPythonRunner(
|
||||||
|
chainedFunc,
|
||||||
|
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
|
||||||
|
Array(leftArgOffsets ++ rightArgOffsets),
|
||||||
|
StructType.fromAttributes(leftDedup),
|
||||||
|
StructType.fromAttributes(rightDedup),
|
||||||
|
sessionLocalTimeZone,
|
||||||
|
pythonRunnerConf)
|
||||||
|
|
||||||
|
executePython(data, runner)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,19 +17,14 @@
|
||||||
|
|
||||||
package org.apache.spark.sql.execution.python
|
package org.apache.spark.sql.execution.python
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import org.apache.spark.api.python.PythonEvalType
|
||||||
import scala.collection.mutable.ArrayBuffer
|
|
||||||
|
|
||||||
import org.apache.spark.TaskContext
|
|
||||||
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
|
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
|
||||||
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
|
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
import org.apache.spark.sql.util.ArrowUtils
|
|
||||||
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
|
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
|
||||||
|
@ -53,14 +48,10 @@ case class FlatMapGroupsInPandasExec(
|
||||||
func: Expression,
|
func: Expression,
|
||||||
output: Seq[Attribute],
|
output: Seq[Attribute],
|
||||||
child: SparkPlan)
|
child: SparkPlan)
|
||||||
extends UnaryExecNode {
|
extends BasePandasGroupExec(func, output) with UnaryExecNode {
|
||||||
|
|
||||||
private val pandasFunction = func.asInstanceOf[PythonUDF].func
|
|
||||||
|
|
||||||
override def outputPartitioning: Partitioning = child.outputPartitioning
|
override def outputPartitioning: Partitioning = child.outputPartitioning
|
||||||
|
|
||||||
override def producedAttributes: AttributeSet = AttributeSet(output)
|
|
||||||
|
|
||||||
override def requiredChildDistribution: Seq[Distribution] = {
|
override def requiredChildDistribution: Seq[Distribution] = {
|
||||||
if (groupingAttributes.isEmpty) {
|
if (groupingAttributes.isEmpty) {
|
||||||
AllTuples :: Nil
|
AllTuples :: Nil
|
||||||
|
@ -75,88 +66,23 @@ case class FlatMapGroupsInPandasExec(
|
||||||
override protected def doExecute(): RDD[InternalRow] = {
|
override protected def doExecute(): RDD[InternalRow] = {
|
||||||
val inputRDD = child.execute()
|
val inputRDD = child.execute()
|
||||||
|
|
||||||
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
|
val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes)
|
||||||
val sessionLocalTimeZone = conf.sessionLocalTimeZone
|
|
||||||
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
|
|
||||||
|
|
||||||
// Deduplicate the grouping attributes.
|
|
||||||
// If a grouping attribute also appears in data attributes, then we don't need to send the
|
|
||||||
// grouping attribute to Python worker. If a grouping attribute is not in data attributes,
|
|
||||||
// then we need to send this grouping attribute to python worker.
|
|
||||||
//
|
|
||||||
// We use argOffsets to distinguish grouping attributes and data attributes as following:
|
|
||||||
//
|
|
||||||
// argOffsets[0] is the length of grouping attributes
|
|
||||||
// argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
|
|
||||||
// argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
|
|
||||||
|
|
||||||
val dataAttributes = child.output.drop(groupingAttributes.length)
|
|
||||||
val groupingIndicesInData = groupingAttributes.map { attribute =>
|
|
||||||
dataAttributes.indexWhere(attribute.semanticEquals)
|
|
||||||
}
|
|
||||||
|
|
||||||
val groupingArgOffsets = new ArrayBuffer[Int]
|
|
||||||
val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
|
|
||||||
val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
|
|
||||||
|
|
||||||
// Non duplicate grouping attributes are added to nonDupGroupingAttributes and
|
|
||||||
// their offsets are 0, 1, 2 ...
|
|
||||||
// Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
|
|
||||||
// their offsets are n + index, where n is the total number of non duplicate grouping
|
|
||||||
// attributes and index is the index in the data attributes that the grouping attribute
|
|
||||||
// is a duplicate of.
|
|
||||||
|
|
||||||
groupingAttributes.zip(groupingIndicesInData).foreach {
|
|
||||||
case (attribute, index) =>
|
|
||||||
if (index == -1) {
|
|
||||||
groupingArgOffsets += nonDupGroupingAttributes.length
|
|
||||||
nonDupGroupingAttributes += attribute
|
|
||||||
} else {
|
|
||||||
groupingArgOffsets += index + nonDupGroupingSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val dataArgOffsets = nonDupGroupingAttributes.length until
|
|
||||||
(nonDupGroupingAttributes.length + dataAttributes.length)
|
|
||||||
|
|
||||||
val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
|
|
||||||
|
|
||||||
// Attributes after deduplication
|
|
||||||
val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
|
|
||||||
val dedupSchema = StructType.fromAttributes(dedupAttributes)
|
|
||||||
|
|
||||||
// Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
|
// Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
|
||||||
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
|
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
|
||||||
val grouped = if (groupingAttributes.isEmpty) {
|
|
||||||
Iterator(iter)
|
|
||||||
} else {
|
|
||||||
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
|
|
||||||
val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
|
|
||||||
groupedIter.map {
|
|
||||||
case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val context = TaskContext.get()
|
val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes)
|
||||||
|
.map{case(_, x) => x}
|
||||||
|
|
||||||
val columnarBatchIter = new ArrowPythonRunner(
|
val runner = new ArrowPythonRunner(
|
||||||
chainedFunc,
|
chainedFunc,
|
||||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||||
argOffsets,
|
Array(argOffsets),
|
||||||
dedupSchema,
|
StructType.fromAttributes(dedupAttributes),
|
||||||
sessionLocalTimeZone,
|
sessionLocalTimeZone,
|
||||||
pythonRunnerConf).compute(grouped, context.partitionId(), context)
|
pythonRunnerConf)
|
||||||
|
|
||||||
val unsafeProj = UnsafeProjection.create(output, output)
|
executePython(data, runner)
|
||||||
|
|
||||||
columnarBatchIter.flatMap { batch =>
|
|
||||||
// Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
|
|
||||||
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
|
|
||||||
val outputVectors = output.indices.map(structVector.getChild)
|
|
||||||
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
|
|
||||||
flattenedBatch.setNumRows(batch.numRows())
|
|
||||||
flattenedBatch.rowIterator.asScala
|
|
||||||
}.map(unsafeProj)
|
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue