[SPARK-28198][PYTHON] Add mapPartitionsInPandas to allow an iterator of DataFrames
## What changes were proposed in this pull request?

This PR proposes to add a `mapPartitionsInPandas` API to DataFrame by reusing the existing `SCALAR_ITER` UDF type, as below:

1. Filtering via setting a column

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

@pandas_udf(df.schema, PandasUDFType.SCALAR_ITER)
def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapPartitionsInPandas(filter_func).show()
```

```
+---+---+
| id|age|
+---+---+
|  1| 21|
+---+---+
```

2. `DataFrame.loc`

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd

df = spark.createDataFrame([['aa'], ['bb'], ['cc'], ['aa'], ['aa'], ['aa']], ["value"])

@pandas_udf(df.schema, PandasUDFType.SCALAR_ITER)
def filter_func(iterator):
    for pdf in iterator:
        yield pdf.loc[pdf.value.str.contains('^a'), :]

df.mapPartitionsInPandas(filter_func).show()
```

```
+-----+
|value|
+-----+
|   aa|
|   aa|
|   aa|
|   aa|
+-----+
```

3. `pandas.melt`

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd

df = spark.createDataFrame(
    pd.DataFrame({'A': {0: 'a', 1: 'b', 2: 'c'},
                  'B': {0: 1, 1: 3, 2: 5},
                  'C': {0: 2, 1: 4, 2: 6}}))

@pandas_udf("A string, variable string, value long", PandasUDFType.SCALAR_ITER)
def filter_func(iterator):
    import pandas as pd
    for pdf in iterator:
        yield pd.melt(pdf, id_vars=['A'], value_vars=['B', 'C'])

df.mapPartitionsInPandas(filter_func).show()
```

```
+---+--------+-----+
|  A|variable|value|
+---+--------+-----+
|  a|       B|    1|
|  a|       C|    2|
|  b|       B|    3|
|  b|       C|    4|
|  c|       B|    5|
|  c|       C|    6|
+---+--------+-----+
```

The current limitation of `SCALAR_ITER` is that the output length must match the input length, which is pretty critical in practice - for instance, we cannot simply filter with pandas APIs; we can only map N input rows to N output rows. This PR allows mapping N rows to M rows, like `flatMap`. The new API mimics `mapPartitions` while keeping the API shape of `SCALAR_ITER`, by allowing results of a different length.

### How is this implemented?

This PR mimics both `dapply` with Arrow optimization and the Grouped Map Pandas UDF. On the Python execution side, it reuses the existing `SCALAR_ITER` code path. Therefore, externally, we don't introduce any new type of Pandas UDF, but internally we use another evaluation type code, `205` (`SQL_MAP_PANDAS_ITER_UDF`). This approach is similar to the window function implementation on top of Grouped Aggregation Pandas UDFs - internally there is `203` (`SQL_WINDOW_AGG_PANDAS_UDF`), but externally it shares the same `GROUPED_AGG` type.

## How was this patch tested?

Manually tested, and unit tests were added.

Closes #24997 from HyukjinKwon/scalar-udf-iter.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent 0a4f985ca0
commit 02f4763286
```diff
@@ -47,6 +47,7 @@ private[spark] object PythonEvalType {
   val SQL_GROUPED_AGG_PANDAS_UDF = 202
   val SQL_WINDOW_AGG_PANDAS_UDF = 203
   val SQL_SCALAR_PANDAS_ITER_UDF = 204
+  val SQL_MAP_PANDAS_ITER_UDF = 205

   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
@@ -56,6 +57,7 @@ private[spark] object PythonEvalType {
     case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
     case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
     case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
+    case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
   }
 }
```
```diff
@@ -74,6 +74,7 @@ class PythonEvalType(object):
     SQL_GROUPED_AGG_PANDAS_UDF = 202
     SQL_WINDOW_AGG_PANDAS_UDF = 203
     SQL_SCALAR_PANDAS_ITER_UDF = 204
+    SQL_MAP_PANDAS_ITER_UDF = 205


 def portable_hash(x):
```
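For quick reference, the full pandas eval-type table after this change looks as below. This is a summary sketch, not code from the patch; codes `200` and `201` are not shown in this diff and are quoted from the surrounding file from memory.

```python
# Pandas UDF eval type codes in PythonEvalType after this change.
# 200 and 201 predate this diff and are listed here for context only.
PANDAS_EVAL_TYPES = {
    200: "SQL_SCALAR_PANDAS_UDF",
    201: "SQL_GROUPED_MAP_PANDAS_UDF",
    202: "SQL_GROUPED_AGG_PANDAS_UDF",
    203: "SQL_WINDOW_AGG_PANDAS_UDF",
    204: "SQL_SCALAR_PANDAS_ITER_UDF",
    205: "SQL_MAP_PANDAS_ITER_UDF",  # added by this PR, internal-only
}
```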
```diff
@@ -28,7 +28,8 @@ else:
     import warnings

 from pyspark import copy_func, since, _NoValue
-from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, ignore_unicode_prefix
+from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, \
+    ignore_unicode_prefix, PythonEvalType
 from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
```
```diff
@@ -2192,6 +2193,51 @@ class DataFrame(object):
                 _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
         return pdf

+    def mapPartitionsInPandas(self, udf):
+        """
+        Maps each partition of the current :class:`DataFrame` using a pandas udf and returns
+        the result as a `DataFrame`.
+
+        The user-defined function should take an iterator of `pandas.DataFrame`s and return another
+        iterator of `pandas.DataFrame`s. For each partition, all columns are passed together as an
+        iterator of `pandas.DataFrame`s to the user-function and the returned iterator of
+        `pandas.DataFrame`s are combined as a :class:`DataFrame`.
+        Each `pandas.DataFrame` size can be controlled by
+        `spark.sql.execution.arrow.maxRecordsPerBatch`.
+        Its schema must match the returnType of the pandas udf.
+
+        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
+
+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> df = spark.createDataFrame([(1, 21), (2, 30)],
+        ...                            ("id", "age"))  # doctest: +SKIP
+        >>> @pandas_udf(df.schema, PandasUDFType.SCALAR_ITER)  # doctest: +SKIP
+        ... def filter_func(iterator):
+        ...     for pdf in iterator:
+        ...         yield pdf[pdf.id == 1]
+        >>> df.mapPartitionsInPandas(filter_func).show()  # doctest: +SKIP
+        +---+---+
+        | id|age|
+        +---+---+
+        |  1| 21|
+        +---+---+
+
+        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
+        """
+        # Columns are special because hasattr always return True
+        if isinstance(udf, Column) or not hasattr(udf, 'func') \
+                or udf.evalType != PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+            raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
+                             "SCALAR_ITER.")
+
+        if not isinstance(udf.returnType, StructType):
+            raise ValueError("The returnType of the pandas_udf must be a StructType")
+
+        udf_column = udf(*[self[col] for col in self.columns])
+        jdf = self._jdf.mapPartitionsInPandas(udf_column._jc.expr())
+        return DataFrame(jdf, self.sql_ctx)
+
     def _collectAsArrow(self):
         """
         Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
```
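Because `mapPartitionsInPandas` does not require the output length to match the input length, a UDF can also expand rows. A hedged N-to-M sketch, assuming a running `spark` session as in the doctest above (`duplicate` is a hypothetical name):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

@pandas_udf(df.schema, PandasUDFType.SCALAR_ITER)
def duplicate(iterator):
    # Yields each batch twice, so N input rows become 2N output rows;
    # a plain SCALAR_ITER invocation would reject this length change.
    for pdf in iterator:
        yield pdf
        yield pdf

df.mapPartitionsInPandas(duplicate).show()
```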
python/pyspark/sql/tests/test_pandas_udf_iter.py (new file, 135 lines)
```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import time
import unittest

if sys.version >= '3':
    unicode = str

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message

if have_pandas:
    import pandas as pd


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class ScalarPandasIterUDFTests(ReusedSQLTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()

        # Synchronize default timezone between Python and Java
        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.sc.environment["TZ"] = tz
        cls.spark.conf.set("spark.sql.session.timeZone", tz)

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        ReusedSQLTestCase.tearDownClass()

    def test_map_partitions_in_pandas(self):
        @pandas_udf('id long', PandasUDFType.SCALAR_ITER)
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert pdf.columns == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapPartitionsInPandas(func).collect()
        expected = df.collect()
        self.assertEquals(actual, expected)

    def test_multiple_columns(self):
        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
        df = self.spark.createDataFrame(data, "a int, b string")

        @pandas_udf(df.schema, PandasUDFType.SCALAR_ITER)
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert [d.name for d in list(pdf.dtypes)] == ['int32', 'object']
                yield pdf

        actual = df.mapPartitionsInPandas(func).collect()
        expected = df.collect()
        self.assertEquals(actual, expected)

    def test_different_output_length(self):
        @pandas_udf('a long', PandasUDFType.SCALAR_ITER)
        def func(iterator):
            for _ in iterator:
                yield pd.DataFrame({'a': list(range(100))})

        df = self.spark.range(10)
        actual = df.repartition(1).mapPartitionsInPandas(func).collect()
        self.assertEquals(set((r.a for r in actual)), set(range(100)))

    def test_empty_iterator(self):
        @pandas_udf('a int, b string', PandasUDFType.SCALAR_ITER)
        def empty_iter(_):
            return iter([])

        self.assertEqual(
            self.spark.range(10).mapPartitionsInPandas(empty_iter).count(), 0)

    def test_empty_rows(self):
        @pandas_udf('a int', PandasUDFType.SCALAR_ITER)
        def empty_rows(_):
            return iter([pd.DataFrame({'a': []})])

        self.assertEqual(
            self.spark.range(10).mapPartitionsInPandas(empty_rows).count(), 0)

    def test_chain_map_partitions_in_pandas(self):
        @pandas_udf('id long', PandasUDFType.SCALAR_ITER)
        def func(iterator):
            for pdf in iterator:
                assert isinstance(pdf, pd.DataFrame)
                assert pdf.columns == ['id']
                yield pdf

        df = self.spark.range(10)
        actual = df.mapPartitionsInPandas(func).mapPartitionsInPandas(func).collect()
        expected = df.collect()
        self.assertEquals(actual, expected)


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_iter import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
```
```diff
@@ -86,7 +86,7 @@ def wrap_udf(f, return_type):
     return lambda *a: f(*a)


-def wrap_scalar_pandas_udf(f, return_type, eval_type):
+def wrap_scalar_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)

     def verify_result_type(result):
```
```diff
@@ -102,13 +102,22 @@ def wrap_scalar_pandas_udf(f, return_type):
                     "expected %d, got %d" % (length, len(result)))
         return result

-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return lambda *a: (verify_result_length(
-            verify_result_type(f(*a)), len(a[0])), arrow_return_type)
-    else:
-        # The result length verification is done at the end of a partition.
-        return lambda *iterator: map(lambda res: (res, arrow_return_type),
-                                     map(verify_result_type, f(*iterator)))
+    return lambda *a: (verify_result_length(
+        verify_result_type(f(*a)), len(a[0])), arrow_return_type)
+
+
+def wrap_pandas_iter_udf(f, return_type):
+    arrow_return_type = to_arrow_type(return_type)
+
+    def verify_result_type(result):
+        if not hasattr(result, "__len__"):
+            pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
+            raise TypeError("Return type of the user-defined function should be "
+                            "{}, but is {}".format(pd_type, type(result)))
+        return result
+
+    return lambda *iterator: map(lambda res: (res, arrow_return_type),
+                                 map(verify_result_type, f(*iterator)))


 def wrap_grouped_map_pandas_udf(f, return_type, argspec):
```
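The essence of `wrap_pandas_iter_udf` is that type verification and Arrow-type tagging are applied lazily, one yielded frame at a time, via nested `map` objects. A standalone sketch of the same pattern in plain Python (`wrap_iter_udf` and the string stand-in for `arrow_return_type` are illustrative, not patch code):

```python
def wrap_iter_udf(f, arrow_return_type):
    # Lazily verify and tag each result the user function yields; nothing is
    # materialized until the caller pulls from the returned map object.
    def verify(result):
        if not hasattr(result, "__len__"):
            raise TypeError("expected a sized pandas object, got %s" % type(result))
        return result

    return lambda *iterator: map(lambda res: (res, arrow_return_type),
                                 map(verify, f(*iterator)))

def gen(batches):
    # A user function in the iterator shape: consume batches, yield batches.
    for b in batches:
        yield b

wrapped = wrap_iter_udf(gen, "long")
print(list(wrapped(iter([[1], [2, 3]]))))  # [([1], 'long'), ([2, 3], 'long')]
```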
```diff
@@ -226,9 +235,11 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):

     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(func, return_type, eval_type)
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(func, return_type, eval_type)
+        return arg_offsets, wrap_pandas_iter_udf(func, return_type)
+    elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
+        return arg_offsets, wrap_pandas_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = _get_argspec(chained_func)  # signature was lost when wrapping it
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
```
```diff
@@ -247,6 +258,7 @@ def read_udfs(pickleSer, infile, eval_type):

     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                      PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                      PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                      PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
```
```diff
@@ -270,7 +282,8 @@ def read_udfs(pickleSer, infile, eval_type):
         # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
         # pandas Series. See SPARK-27240.
         df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
-                         eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
+                         eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or
+                         eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
         ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
                                              df_for_struct)
     else:
```
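The `df_for_struct` flag referenced here makes struct-typed arguments arrive as `pandas.DataFrame` rather than `pandas.Series`. A hedged illustration of that behavior, assuming the SPARK-27240 semantics the comment cites and an active `spark` session (`sum_struct` is a hypothetical UDF name):

```python
from pyspark.sql.functions import pandas_udf, struct, PandasUDFType
import pandas as pd

@pandas_udf("long", PandasUDFType.SCALAR)
def sum_struct(s):
    # With df_for_struct, the struct argument arrives as a pandas.DataFrame
    # whose columns are the struct fields, not as a pandas.Series.
    assert isinstance(s, pd.DataFrame)
    return s["id"] + s["age"]

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
df.select(sum_struct(struct("id", "age"))).show()
```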
```diff
@@ -278,8 +291,11 @@ def read_udfs(pickleSer, infile, eval_type):

     num_udfs = read_int(infile)

-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        assert num_udfs == 1, "One SQL_SCALAR_PANDAS_ITER_UDF expected here."
+    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+
+    if is_scalar_iter or is_map_iter:
+        assert num_udfs == 1, "One SCALAR_ITER UDF expected here."

         arg_offsets, udf = read_single_udf(
             pickleSer, infile, eval_type, runner_conf, udf_index=0)
```
```diff
@@ -301,20 +317,20 @@ def read_udfs(pickleSer, infile, eval_type):
         num_output_rows = 0
         for result_batch, result_type in result_iter:
             num_output_rows += len(result_batch)
-            assert num_output_rows <= num_input_rows[0], \
+            assert is_map_iter or num_output_rows <= num_input_rows[0], \
                 "Pandas SCALAR_ITER UDF outputted more rows than input rows."
             yield (result_batch, result_type)

-        try:
-            if sys.version >= '3':
-                iterator.__next__()
-            else:
-                iterator.next()
-        except StopIteration:
-            pass
-        else:
-            raise RuntimeError("SQL_SCALAR_PANDAS_ITER_UDF should exhaust the input iterator.")
-
-        if num_output_rows != num_input_rows[0]:
+        if is_scalar_iter:
+            try:
+                next(iterator)
+            except StopIteration:
+                pass
+            else:
+                raise RuntimeError("SQL_SCALAR_PANDAS_ITER_UDF should exhaust the input "
+                                   "iterator.")
+
+        if is_scalar_iter and num_output_rows != num_input_rows[0]:
             raise RuntimeError("The number of output rows of pandas iterator UDF should be "
                                "the same with input rows. The input rows number is %d but the "
                                "output rows number is %d." %
```
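The difference between the two branches above is the contract: a SCALAR_ITER UDF must exhaust its input and preserve the row count, while a MAP_ITER UDF may stop early and change the count. A hedged sketch of the SCALAR_ITER invariant in isolation (`check_scalar_iter_contract` is an illustrative helper, not worker.py code):

```python
def check_scalar_iter_contract(udf, input_batches):
    # Mirrors the worker's SCALAR_ITER checks: the UDF must exhaust its
    # input iterator and emit exactly as many rows as it read.
    it = iter(input_batches)
    num_in = sum(len(b) for b in input_batches)
    num_out = sum(len(b) for b in udf(it))
    try:
        next(it)
    except StopIteration:
        pass
    else:
        raise RuntimeError("SCALAR_ITER UDF should exhaust the input iterator.")
    if num_out != num_in:
        raise RuntimeError("SCALAR_ITER UDF must output as many rows as it read.")

identity = lambda it: (b for b in it)              # N-to-N: passes
check_scalar_iter_contract(identity, [[1, 2], [3]])
```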
```diff
@@ -39,6 +39,18 @@ case class FlatMapGroupsInPandas(
   override val producedAttributes = AttributeSet(output)
 }

+/**
+ * Map partitions using an udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame).
+ * This is used by DataFrame.mapPartitionsInPandas()
+ */
+case class MapPartitionsInPandas(
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override val producedAttributes = AttributeSet(output)
+}
+
 trait BaseEvalPython extends UnaryNode {

   def udfs: Seq[PythonUDF]
```
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
-import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.python.{PythonEvalType, PythonRDD, SerDeUtil}
 import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
```
```diff
@@ -2643,6 +2643,25 @@ class Dataset[T] private[sql](
       MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
   }

+  /**
+   * Applies a Scalar iterator Pandas UDF to each partition. The user-defined function
+   * defines a transformation: `iter(pandas.DataFrame)` -> `iter(pandas.DataFrame)`.
+   * Each partition is an iterator consisting of DataFrames as batches.
+   *
+   * This function uses Apache Arrow as serialization format between Java executors and Python
+   * workers.
+   */
+  private[sql] def mapPartitionsInPandas(f: PythonUDF): DataFrame = {
+    Dataset.ofRows(
+      sparkSession,
+      MapPartitionsInPandas(
+        // Here, the evalType is SQL_SCALAR_PANDAS_ITER_UDF since we share the
+        // same Pandas type. To avoid conflicts, it sets SQL_MAP_PANDAS_ITER_UDF here.
+        f.copy(evalType = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF),
+        f.dataType.asInstanceOf[StructType].toAttributes,
+        logicalPlan))
+  }
+
   /**
    * :: Experimental ::
    * (Scala-specific)
```
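As the comment inside `mapPartitionsInPandas` notes, the eval-type swap is internal only: from Python, the user-facing UDF still reports `SCALAR_ITER`. A hedged illustration, assuming an active session (`evalType` is the same attribute the dataframe.py validation reads; `identity` is a hypothetical UDF name):

```python
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("id long", PandasUDFType.SCALAR_ITER)
def identity(iterator):
    for pdf in iterator:
        yield pdf

# Externally the UDF keeps the SCALAR_ITER eval type (204); the plan-side
# copy is rewritten to SQL_MAP_PANDAS_ITER_UDF (205) only inside Dataset.scala.
assert identity.evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
```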
```diff
@@ -682,6 +682,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           f, p, b, is, ot, planLater(child)) :: Nil
       case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
         execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
+      case logical.MapPartitionsInPandas(func, output, child) =>
+        execution.python.MapPartitionsInPandasExec(func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>
```
MapPartitionsInPandasExec (new file, 95 lines)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}

/**
 * A relation produced by applying a function that takes an iterator of pandas DataFrames
 * and outputs an iterator of pandas DataFrames.
 *
 * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and
 * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow`
 */
case class MapPartitionsInPandasExec(
    func: Expression,
    output: Seq[Attribute],
    child: SparkPlan)
  extends UnaryExecNode {

  private val pandasFunction = func.asInstanceOf[PythonUDF].func

  override def producedAttributes: AttributeSet = AttributeSet(output)

  private val batchSize = conf.arrowMaxRecordsPerBatch

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitionsInternal { inputIter =>
      // Single function with one struct.
      val argOffsets = Array(Array(0))
      val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
      val sessionLocalTimeZone = conf.sessionLocalTimeZone
      val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
      val outputTypes = child.schema

      // Here we wrap it via another row so that Python sides understand it
      // as a DataFrame.
      val wrappedIter = inputIter.map(InternalRow(_))

      // DO NOT use iter.grouped(). See BatchIterator.
      val batchIter =
        if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter)

      val context = TaskContext.get()

      val columnarBatchIter = new ArrowPythonRunner(
        chainedFunc,
        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
        argOffsets,
        StructType(StructField("struct", outputTypes) :: Nil),
        sessionLocalTimeZone,
        pythonRunnerConf).compute(batchIter, context.partitionId(), context)

      val unsafeProj = UnsafeProjection.create(output, output)

      columnarBatchIter.flatMap { batch =>
        // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select
        // the children here
        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
        val outputVectors = output.indices.map(structVector.getChild)
        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
        flattenedBatch.setNumRows(batch.numRows())
        flattenedBatch.rowIterator.asScala
      }.map(unsafeProj)
    }
  }
}
```
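On the `DO NOT use iter.grouped()` comment above: Scala's `grouped` buffers each group eagerly, whereas `BatchIterator` re-batches the row stream lazily. The same idea expressed in Python, as a sketch of the pattern rather than the Scala class:

```python
from itertools import chain, islice

def batch_iterator(rows, batch_size):
    # Re-batch a row iterator into lazy chunks of at most batch_size rows.
    # Each yielded chunk is itself an iterator, so no batch is materialized
    # up front; as with Spark's BatchIterator, the caller must finish one
    # chunk before asking for the next.
    rows = iter(rows)
    for first in rows:
        yield chain([first], islice(rows, batch_size - 1))

# Example: three batches from ten rows.
for batch in batch_iterator(range(10), 4):
    print(list(batch))  # [0..3], then [4..7], then [8, 9]
```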