[SPARK-30434][PYTHON][SQL] Move pandas related functionalities into 'pandas' sub-package

### What changes were proposed in this pull request?

This PR proposes to move pandas related functionalities into pandas package. Namely:

```bash
pyspark/sql/pandas
├── __init__.py
├── conversion.py  # Conversion between pandas <> PySpark DataFrames
├── functions.py   # pandas_udf
├── group_ops.py   # Grouped UDF / Cogrouped UDF + groupby.apply, groupby.cogroup.apply
├── map_ops.py     # Map Iter UDF + mapInPandas
├── serializers.py # pandas <> PyArrow serializers
├── types.py       # Type utils between pandas <> PyArrow
└── utils.py       # Version requirement checks
```

In order to separately locate `groupby.apply`, `groupby.cogroup.apply`, `mapInPandas`, `toPandas`, and `createDataFrame(pdf)` under `pandas` sub-package, I had to use a mix-in approach which Scala side uses often by `trait`, and also pandas itself uses this approach (see `IndexOpsMixin` as an example) to group related functionalities. Currently, you can think it's like Scala's self typed trait. See the structure below:

```python
class PandasMapOpsMixin(object):
    def mapInPandas(self, ...):
        ...
        return ...

    # other Pandas <> PySpark APIs
```

```python
class DataFrame(PandasMapOpsMixin):

    # other DataFrame APIs equivalent to Scala side.

```

Yes, This is a big PR but they are mostly just moving around except one case `createDataFrame` which I had to split the methods.

### Why are the changes needed?

There are pandas functionalities here and there and I myself gets lost where it was. Also, when you have to make a change commonly for all of pandas related features, it's almost impossible now.

Also, after this change, `DataFrame` and `SparkSession` become more consistent with Scala side since pandas is specific to Python, and this change separates pandas-specific APIs away from `DataFrame` or `SparkSession`.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests should cover. Also, I manually built the PySpark API documentation and checked.

Closes #27109 from HyukjinKwon/pandas-refactoring.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
HyukjinKwon 2020-01-09 10:22:50 +09:00
parent 18daa37cdb
commit ee8d661058
25 changed files with 1840 additions and 1514 deletions

View file

@ -362,6 +362,13 @@ pyspark_sql = Module(
"pyspark.sql.udf",
"pyspark.sql.window",
"pyspark.sql.avro.functions",
"pyspark.sql.pandas.conversion",
"pyspark.sql.pandas.map_ops",
"pyspark.sql.pandas.functions",
"pyspark.sql.pandas.group_ops",
"pyspark.sql.pandas.types",
"pyspark.sql.pandas.serializers",
"pyspark.sql.pandas.utils",
# unittests
"pyspark.sql.tests.test_arrow",
"pyspark.sql.tests.test_catalog",

View file

@ -24,7 +24,7 @@ Run with:
from __future__ import print_function
from pyspark.sql import SparkSession
from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
require_minimum_pandas_version()
require_minimum_pyarrow_version()

View file

@ -7,6 +7,7 @@ Module Context
.. automodule:: pyspark.sql
:members:
:undoc-members:
:inherited-members:
:exclude-members: builder
.. We need `exclude-members` to prevent default description generations
as a workaround for old Sphinx (< 1.6.6).

View file

@ -185,248 +185,6 @@ class FramedSerializer(Serializer):
raise NotImplementedError
class ArrowCollectSerializer(Serializer):
"""
Deserialize a stream of batches followed by batch order information. Used in
DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
"""
def __init__(self):
self.serializer = ArrowStreamSerializer()
def dump_stream(self, iterator, stream):
return self.serializer.dump_stream(iterator, stream)
def load_stream(self, stream):
"""
Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
a list of indices that can be used to put the RecordBatches in the correct order.
"""
# load the batches
for batch in self.serializer.load_stream(stream):
yield batch
# load the batch order indices or propagate any error that occurred in the JVM
num = read_int(stream)
if num == -1:
error_msg = UTF8Deserializer().loads(stream)
raise RuntimeError("An error occurred while calling "
"ArrowCollectSerializer.load_stream: {}".format(error_msg))
batch_order = []
for i in xrange(num):
index = read_int(stream)
batch_order.append(index)
yield batch_order
def __repr__(self):
return "ArrowCollectSerializer(%s)" % self.serializer
class ArrowStreamSerializer(Serializer):
"""
Serializes Arrow record batches as a stream.
"""
def dump_stream(self, iterator, stream):
import pyarrow as pa
writer = None
try:
for batch in iterator:
if writer is None:
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
finally:
if writer is not None:
writer.close()
def load_stream(self, stream):
import pyarrow as pa
reader = pa.ipc.open_stream(stream)
for batch in reader:
yield batch
def __repr__(self):
return "ArrowStreamSerializer"
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
Serializes Pandas.Series as Arrow data with Arrow streaming format.
:param timezone: A timezone to respect when handling timestamp values
:param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
:param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
"""
def __init__(self, timezone, safecheck, assign_cols_by_name):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name
def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import _check_series_localize_timestamps
# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
s = arrow_column.to_pandas(date_as_object=True)
s = _check_series_localize_timestamps(s, self._timezone)
return s
def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series,
with optional type.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
"conversions warned by Arrow. Arrow safe type check can be " + \
"disabled by using SQL config " + \
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
raise RuntimeError(error_msg % (s.dtype, t), e)
return array
arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s)))
# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, basestring)
for name in s.columns):
arrs_names = [(create_array(s[field.name], field.type), field.name)
for field in t]
# Assign result columns by position
else:
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)]
struct_arrs, struct_names = zip(*arrs_names)
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
batches = (self._create_batch(series) for series in iterator)
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
import pyarrow as pa
for batch in batches:
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
def __repr__(self):
return "ArrowStreamPandasSerializer"
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""
Serializer used by Python worker to evaluate Pandas UDFs
"""
def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
super(ArrowStreamPandasUDFSerializer, self) \
.__init__(timezone, safecheck, assign_cols_by_name)
self._df_for_struct = df_for_struct
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types
if self._df_for_struct and types.is_struct(arrow_column.type):
import pandas as pd
series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
.rename(field.name)
for column, field in zip(arrow_column.flatten(), arrow_column.type)]
s = pd.concat(series, axis=1)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
return s
def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""
def init_stream_yield_batches():
should_write_start_length = True
for series in iterator:
batch = self._create_batch(series)
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch
return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
def __repr__(self):
return "ArrowStreamPandasUDFSerializer"
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
def load_stream(self, stream):
"""
Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
lists of pandas.Series.
"""
import pyarrow as pa
dataframes_in_group = None
while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)
if dataframes_in_group == 2:
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
)
elif dataframes_in_group != 0:
raise ValueError(
'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))
class BatchedSerializer(Serializer):
"""

View file

@ -51,12 +51,12 @@ from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStat
from pyspark.sql.group import GroupedData
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec
from pyspark.sql.cogroup import CoGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps
__all__ = [
'SparkSession', 'SQLContext', 'UDFRegistration',
'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
'DataFrameReader', 'DataFrameWriter', 'PandasCogroupedOps'
]

View file

@ -31,8 +31,8 @@ import warnings
from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, \
ignore_unicode_prefix, PythonEvalType
from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@ -40,14 +40,14 @@ from pyspark.sql.types import _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter
from pyspark.sql.streaming import DataStreamWriter
from pyspark.sql.types import IntegralType
from pyspark.sql.types import *
from pyspark.util import _exception_message
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
class DataFrame(object):
class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
"""A distributed collection of data grouped into named columns.
A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
@ -2135,193 +2135,6 @@ class DataFrame(object):
"should have been DataFrame." % type(result)
return result
@since(1.3)
def toPandas(self):
"""
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
This is only available if Pandas is installed and available.
.. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is
expected to be small, as all the data is loaded into the driver's memory.
.. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.
>>> df.toPandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
import numpy as np
import pandas as pd
if self.sql_ctx._conf.pandasRespectSessionTimeZone():
timezone = self.sql_ctx._conf.sessionLocalTimeZone()
else:
timezone = None
if self.sql_ctx._conf.arrowPySparkEnabled():
use_arrow = True
try:
from pyspark.sql.types import to_arrow_schema
from pyspark.sql.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
to_arrow_schema(self.schema)
except Exception as e:
if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
"Attempting non-optimization as "
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
"true." % _exception_message(e))
warnings.warn(msg)
use_arrow = False
else:
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
"reached the error below and will not continue because automatic fallback "
"with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
"false.\n %s" % _exception_message(e))
warnings.warn(msg)
raise
# Try to use Arrow optimization when the schema is supported and the required version
# of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow
batches = self._collectAsArrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
# Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
# values, but we should use datetime.date to match the behavior with when
# Arrow optimization is disabled.
pdf = table.to_pandas(date_as_object=True)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except Exception as e:
# We might have to allow fallback here as well but multiple Spark jobs can
# be executed. So, simply fail in this case for now.
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
"reached the error below and can not continue. Note that "
"'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
"effect on failures in the middle of "
"computation.\n %s" % _exception_message(e))
warnings.warn(msg)
raise
# Below is toPandas without Arrow optimization.
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
dtype = {}
for field in self.schema:
pandas_type = _to_corrected_pandas_type(field.dataType)
# SPARK-21766: if an integer field is nullable and has null values, it can be
# inferred by pandas as float column. Once we convert the column with NaN back
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
# float type, not the corrected type from the schema in this case.
if pandas_type is not None and \
not(isinstance(field.dataType, IntegralType) and field.nullable and
pdf[field.name].isnull().any()):
dtype[field.name] = pandas_type
# Ensure we fall back to nullable numpy types, even when whole column is null:
if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
dtype[field.name] = np.float64
if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
dtype[field.name] = np.object
for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)
if timezone is None:
return pdf
else:
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf
def mapInPandas(self, udf):
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
function and returns the result as a :class:`DataFrame`.
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return
another iterator of `pandas.DataFrame`\\s. All columns are passed
together as an iterator of `pandas.DataFrame`\\s to the user-defined function and the
returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
Each `pandas.DataFrame` size can be controlled by
`spark.sql.execution.arrow.maxRecordsPerBatch`.
Its schema must match the returnType of the Pandas user-defined function.
:param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
... ("id", "age")) # doctest: +SKIP
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
... def filter_func(batch_iter):
... for pdf in batch_iter:
... yield pdf[pdf.id == 1]
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
"""
# Columns are special because hasattr always return True
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.evalType != PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
"MAP_ITER.")
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)
def _collectAsArrow(self):
"""
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
and available on driver and worker Python environments.
.. note:: Experimental.
"""
with SCCallSiteSync(self._sc) as css:
port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
# Collect list of un-ordered batches where last element is a list of correct order indices
try:
results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
finally:
# Join serving thread and raise any exceptions from collectAsArrowToPython
jsocket_auth_server.getResult()
# Separate RecordBatches from batch order indices in results
batches = results[:-1]
batch_order = results[-1]
# Re-order the batch list using the correct order
return [batches[i] for i in batch_order]
##########################################################################################
# Pandas compatibility
##########################################################################################
@ -2349,33 +2162,6 @@ def _to_scala_map(sc, jm):
return sc._jvm.PythonUtils.toScalaMap(jm)
def _to_corrected_pandas_type(dt):
"""
When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type may be
wrong. This method gets the corrected data type for Pandas if that type may be inferred
uncorrectly.
"""
import numpy as np
if type(dt) == ByteType:
return np.int8
elif type(dt) == ShortType:
return np.int16
elif type(dt) == IntegerType:
return np.int32
elif type(dt) == LongType:
return np.int64
elif type(dt) == FloatType:
return np.float32
elif type(dt) == DoubleType:
return np.float64
elif type(dt) == BooleanType:
return np.bool
elif type(dt) == TimestampType:
return np.datetime64
else:
return None
class DataFrameNaFunctions(object):
"""Functionality for working with missing data in :class:`DataFrame`.

View file

@ -36,6 +36,8 @@ from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, DataType
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_udf
# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
from pyspark.sql.utils import to_str
# Note to developers: all of PySpark functions here take string as column names whenever possible.
@ -2814,22 +2816,6 @@ def from_csv(col, schema, options={}):
# ---------------------------- User Defined Function ----------------------------------
class PandasUDFType(object):
"""Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`.
"""
SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF
SCALAR_ITER = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
@since(1.3)
def udf(f=None, returnType=StringType()):
"""Creates a user defined function (UDF).
@ -2917,483 +2903,6 @@ def udf(f=None, returnType=StringType()):
evalType=PythonEvalType.SQL_BATCHED_UDF)
@since(2.3)
def pandas_udf(f=None, returnType=None, functionType=None):
"""
Creates a vectorized user defined function (UDF).
:param f: user-defined function. A python function if used as a standalone function
:param returnType: the return type of the user-defined function. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
:param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
Default: SCALAR.
The function type of the UDF can be one of the following:
1. SCALAR
A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
:class:`MapType`, nested :class:`StructType` are currently not supported as output types.
Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql.types import IntegerType, StringType
>>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP
>>> @pandas_udf(StringType()) # doctest: +SKIP
... def to_upper(s):
... return s.str.upper()
...
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> df = spark.createDataFrame([(1, "John Doe", 21)],
... ("id", "name", "age")) # doctest: +SKIP
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
... .show() # doctest: +SKIP
+----------+--------------+------------+
|slen(name)|to_upper(name)|add_one(age)|
+----------+--------------+------------+
| 8| JOHN DOE| 22|
+----------+--------------+------------+
>>> @pandas_udf("first string, last string") # doctest: +SKIP
... def split_expand(n):
... return n.str.split(expand=True)
>>> df.select(split_expand("name")).show() # doctest: +SKIP
+------------------+
|split_expand(name)|
+------------------+
| [John, Doe]|
+------------------+
.. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
column, but is the length of an internal batch used for each call to the function.
Therefore, this can be used, for example, to ensure the length of each returned
`pandas.Series`, and can not be used as the column length.
2. SCALAR_ITER
A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the
wrapped Python function takes an iterator of batches as input instead of a single batch and,
instead of returning a single output batch, it yields output batches or explicitly returns an
generator or an iterator of output batches.
It is useful when the UDF execution requires initializing some state, e.g., loading a machine
learning model file to apply inference to every input batch.
.. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all
batches from one partition, although it is currently implemented this way.
Your code shall not rely on this behavior because it might change in the future for
further optimization, e.g., one invocation processes multiple partitions.
Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
>>> import pandas as pd # doctest: +SKIP
>>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType
>>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP
>>> df = spark.createDataFrame(pdf) # doctest: +SKIP
When the UDF is called with a single column that is not `StructType`, the input to the
underlying function is an iterator of `pd.Series`.
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def plus_one(batch_iter):
... for x in batch_iter:
... yield x + 1
...
>>> df.select(plus_one(col("x"))).show() # doctest: +SKIP
+-----------+
|plus_one(x)|
+-----------+
| 2|
| 3|
| 4|
+-----------+
When the UDF is called with more than one columns, the input to the underlying function is an
iterator of `pd.Series` tuple.
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def multiply_two_cols(batch_iter):
... for a, b in batch_iter:
... yield a * b
...
>>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP
+-----------------------+
|multiply_two_cols(x, x)|
+-----------------------+
| 1|
| 4|
| 9|
+-----------------------+
When the UDF is called with a single column that is `StructType`, the input to the underlying
function is an iterator of `pd.DataFrame`.
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def multiply_two_nested_cols(pdf_iter):
... for pdf in pdf_iter:
... yield pdf["a"] * pdf["b"]
...
>>> df.select(
... multiply_two_nested_cols(
... struct(col("x").alias("a"), col("x").alias("b"))
... ).alias("y")
... ).show() # doctest: +SKIP
+---+
| y|
+---+
| 1|
| 4|
| 9|
+---+
In the UDF, you can initialize some states before processing batches, wrap your code with
`try ... finally ...` or use context managers to ensure the release of resources at the end
or in case of early termination.
>>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def plus_y(batch_iter):
... y = y_bc.value # initialize some state
... try:
... for x in batch_iter:
... yield x + y
... finally:
... pass # release resources here, if any
...
>>> df.select(plus_y(col("x"))).show() # doctest: +SKIP
+---------+
|plus_y(x)|
+---------+
| 2|
| 3|
| 4|
+---------+
3. GROUPED_MAP
A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
The returnType should be a :class:`StructType` describing the schema of the returned
`pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match
the field names in the defined returnType schema if specified as strings, or match the
field data types by position if not strings, e.g. integer indices.
The length of the returned `pandas.DataFrame` can be arbitrary.
Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
Alternatively, the user can define a function that takes two arguments.
In this case, the grouping key(s) will be passed as the first argument and the data will
be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
This is useful when the user does not want to hardcode grouping key(s) in the function.
>>> import pandas as pd # doctest: +SKIP
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def mean_udf(key, pdf):
... # key is a tuple of one numpy.int64, which is the value
... # of 'id' for the current group
... return pd.DataFrame([key + (pdf.v.mean(),)])
>>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
+---+---+
| id| v|
+---+---+
| 1|1.5|
| 2|6.0|
+---+---+
>>> @pandas_udf(
... "id long, `ceil(v / 2)` long, v double",
... PandasUDFType.GROUPED_MAP) # doctest: +SKIP
>>> def sum_udf(key, pdf):
... # key is a tuple of two numpy.int64s, which is the values
... # of 'id' and 'ceil(df.v / 2)' for the current group
... return pd.DataFrame([key + (pdf.v.sum(),)])
>>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP
+---+-----------+----+
| id|ceil(v / 2)| v|
+---+-----------+----+
| 2| 5|10.0|
| 1| 1| 3.0|
| 2| 3| 5.0|
| 2| 2| 3.0|
+---+-----------+----+
.. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
recommended to explicitly index the columns by name to ensure the positions are correct,
or alternatively use an `OrderedDict`.
For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
`pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
4. GROUPED_AGG
A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
The returned scalar can be either a python primitive type, e.g., `int` or `float`
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
:class:`MapType` and :class:`StructType` are currently not supported as output types.
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
:class:`pyspark.sql.Window`
This example shows using grouped aggregated UDFs with groupby:
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
>>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP
+---+-----------+
| id|mean_udf(v)|
+---+-----------+
| 1| 1.5|
| 2| 6.0|
+---+-----------+
This example shows using grouped aggregated UDFs as window functions.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql import Window
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
>>> w = (Window.partitionBy('id')
... .orderBy('v')
... .rowsBetween(-1, 0))
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
+---+----+------+
| id| v|mean_v|
+---+----+------+
| 1| 1.0| 1.0|
| 1| 2.0| 1.5|
| 2| 3.0| 3.0|
| 2| 5.0| 4.0|
| 2|10.0| 7.5|
+---+----+------+
.. note:: For performance reasons, the input series to window functions are not copied.
Therefore, mutating the input series is not allowed and will cause incorrect results.
For the same reason, users should also not rely on the index of the input series.
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
5. MAP_ITER
A map iterator Pandas UDFs are used to transform data with an iterator of batches.
It can be used with :meth:`pyspark.sql.DataFrame.mapInPandas`.
It can return the output of arbitrary length in contrast to the scalar Pandas UDF.
It maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
function and returns the result as a :class:`DataFrame`.
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return another
iterator of `pandas.DataFrame`\\s. All columns are passed together as an
iterator of `pandas.DataFrame`\\s to the user-defined function and the returned iterator of
`pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
... ("id", "age")) # doctest: +SKIP
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
... def filter_func(batch_iter):
... for pdf in batch_iter:
... yield pdf[pdf.id == 1]
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
6. COGROUPED_MAP
A cogrouped map UDF defines transformation: (`pandas.DataFrame`, `pandas.DataFrame`) ->
`pandas.DataFrame`. The `returnType` should be a :class:`StructType` describing the schema
of the returned `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame`
must either match the field names in the defined `returnType` schema if specified as strings,
or match the field data types by position if not strings, e.g. integer indices. The length
of the returned `pandas.DataFrame` can be arbitrary.
CoGrouped map UDFs are used with :meth:`pyspark.sql.CoGroupedData.apply`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df1 = spark.createDataFrame(
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
... ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
... [(20000101, 1, "x"), (20000101, 2, "y")],
... ("time", "id", "v2"))
>>> @pandas_udf("time int, id int, v1 double, v2 string",
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
... def asof_join(l, r):
... return pd.merge_asof(l, r, on="time", by="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
+---------+---+---+---+
| time| id| v1| v2|
+---------+---+---+---+
| 20000101| 1|1.0| x|
| 20000102| 1|3.0| x|
| 20000101| 2|2.0| y|
| 20000102| 2|4.0| y|
+---------+---+---+---+
Alternatively, the user can define a function that takes three arguments. In this case,
the grouping key(s) will be passed as the first argument and the data will be passed as the
second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
`pandas.DataFrame` containing all columns from the original Spark DataFrames.
>>> @pandas_udf("time int, id int, v1 double, v2 string",
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
... def asof_join(k, l, r):
... if k == (1,):
... return pd.merge_asof(l, r, on="time", by="id")
... else:
... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
+---------+---+---+---+
| time| id| v1| v2|
+---------+---+---+---+
| 20000101| 1|1.0| x|
| 20000102| 1|3.0| x|
+---------+---+---+---+
.. note:: The user-defined functions are considered deterministic by default. Due to
optimization, duplicate invocations may be eliminated or the function may even be invoked
more times than it is present in the query. If your function is not deterministic, call
`asNondeterministic` on the user defined function. E.g.:
>>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP
... def random(v):
... import numpy as np
... import pandas as pd
... return pd.Series(np.random.randn(len(v))
>>> random = random.asNondeterministic() # doctest: +SKIP
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
.. note:: The user-defined functions do not take keyword arguments on the calling side.
.. note:: The data type of returned `pandas.Series` from the user-defined functions should be
matched with defined returnType (see :meth:`types.to_arrow_type` and
:meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do
conversion on returned data. The conversion is not guaranteed to be correct and results
should be checked for accuracy by users.
"""
# The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that
# are not yet visible to the user. Some of behaviors are buggy and might be changed in the near
# future. The table might have to be eventually documented externally.
# Please see SPARK-28132's PR to see the codes in order to generate the table below.
#
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
# |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
# | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| X| X| # noqa
# | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| 0| X| # noqa
# | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
# | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
# | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| X| X| # noqa
# | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
# | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
# | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| X| X| # noqa
# | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| X| X| # noqa
# | string| None| ''| ''| ''| '\x01'| '\x01'| ''| ''| '\x01'| '\x01'| ''| ''| ''| X| X| 'a'| X| X| ''| X| ''| X| X| # noqa
# | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa
# | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa
# | map<string,int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
# | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
# | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')| X| bytearray(b'')| # noqa
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
#
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
# used in `returnType`.
# Note: The values inside of the table are generated by `repr`.
# Note: Python 3.7.3, Pandas 0.24.2 and PyArrow 0.13.0 are used.
# Note: Timezone is KST.
# Note: 'X' means it throws an exception during the conversion.
# decorator @pandas_udf(returnType, functionType)
is_decorator = f is None or isinstance(f, (str, DataType))
if is_decorator:
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
if functionType is not None:
# @pandas_udf(dataType, functionType=functionType)
# @pandas_udf(returnType=dataType, functionType=functionType)
eval_type = functionType
elif returnType is not None and isinstance(returnType, int):
# @pandas_udf(dataType, functionType)
eval_type = returnType
else:
# @pandas_udf(dataType) or @pandas_udf(returnType=dataType)
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
else:
return_type = returnType
if functionType is not None:
eval_type = functionType
else:
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
if return_type is None:
raise ValueError("Invalid returnType: returnType can not be None")
if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")
if is_decorator:
return functools.partial(_create_udf, returnType=return_type, evalType=eval_type)
else:
return _create_udf(f=f, returnType=return_type, evalType=eval_type)
blacklist = ['map', 'since', 'ignore_unicode_prefix']
__all__ = [k for k, v in globals().items()
if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist]

View file

@ -18,11 +18,11 @@
import sys
from pyspark import since
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
from pyspark.sql.types import *
from pyspark.sql.cogroup import CoGroupedData
__all__ = ["GroupedData"]
@ -47,7 +47,7 @@ def df_varargs_api(f):
return _api
class GroupedData(object):
class GroupedData(PandasGroupedOpsMixin):
"""
A set of methods for aggregations on a :class:`DataFrame`,
created by :func:`DataFrame.groupBy`.
@ -219,68 +219,6 @@ class GroupedData(object):
jgd = self._jgd.pivot(pivot_col, values)
return GroupedData(jgd, self._df)
@since(3.0)
def cogroup(self, other):
"""
Cogroups this group with another group so that we can run cogrouped operations.
See :class:`CoGroupedData` for the operations that can be run.
"""
return CoGroupedData(self, other)
@since(2.3)
def apply(self, udf):
"""
Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
as a `DataFrame`.
The user-defined function should take a `pandas.DataFrame` and return another
`pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
to the user-function and the returned `pandas.DataFrame` are combined as a
:class:`DataFrame`.
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.
.. note:: This function requires a full shuffle. All the data of a group will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
:param udf: a grouped map user-defined function returned by
:func:`pyspark.sql.functions.pandas_udf`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
"""
# Columns are special because hasattr always return True
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
"GROUPED_MAP.")
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)
def _test():
import doctest

View file

@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This package includes the internal APIs for PySpark about interoperability
between pandas, PySpark and PyArrow. This package should not be directly
imported and used.
"""

View file

@ -0,0 +1,431 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import warnings
if sys.version >= '3':
basestring = unicode = str
xrange = range
else:
from itertools import izip as zip
from pyspark import since
from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import *
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.util import _exception_message
class PandasConversionMixin(object):
"""
Min-in for the conversion from Spark to pandas. Currently, only :class:`DataFrame`
can use this class.
"""
@since(1.3)
def toPandas(self):
"""
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
This is only available if Pandas is installed and available.
.. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is
expected to be small, as all the data is loaded into the driver's memory.
.. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.
>>> df.toPandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
"""
from pyspark.sql.dataframe import DataFrame
assert isinstance(self, DataFrame)
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
import numpy as np
import pandas as pd
if self.sql_ctx._conf.pandasRespectSessionTimeZone():
timezone = self.sql_ctx._conf.sessionLocalTimeZone()
else:
timezone = None
if self.sql_ctx._conf.arrowPySparkEnabled():
use_arrow = True
try:
from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
to_arrow_schema(self.schema)
except Exception as e:
if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
"Attempting non-optimization as "
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
"true." % _exception_message(e))
warnings.warn(msg)
use_arrow = False
else:
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
"reached the error below and will not continue because automatic fallback "
"with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
"false.\n %s" % _exception_message(e))
warnings.warn(msg)
raise
# Try to use Arrow optimization when the schema is supported and the required version
# of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.pandas.types import _check_dataframe_localize_timestamps
import pyarrow
batches = self._collect_as_arrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
# Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
# values, but we should use datetime.date to match the behavior with when
# Arrow optimization is disabled.
pdf = table.to_pandas(date_as_object=True)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except Exception as e:
# We might have to allow fallback here as well but multiple Spark jobs can
# be executed. So, simply fail in this case for now.
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
"reached the error below and can not continue. Note that "
"'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
"effect on failures in the middle of "
"computation.\n %s" % _exception_message(e))
warnings.warn(msg)
raise
# Below is toPandas without Arrow optimization.
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
dtype = {}
for field in self.schema:
pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
# SPARK-21766: if an integer field is nullable and has null values, it can be
# inferred by pandas as float column. Once we convert the column with NaN back
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
# float type, not the corrected type from the schema in this case.
if pandas_type is not None and \
not(isinstance(field.dataType, IntegralType) and field.nullable and
pdf[field.name].isnull().any()):
dtype[field.name] = pandas_type
# Ensure we fall back to nullable numpy types, even when whole column is null:
if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
dtype[field.name] = np.float64
if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
dtype[field.name] = np.object
for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)
if timezone is None:
return pdf
else:
from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
return pdf
@staticmethod
def _to_corrected_pandas_type(dt):
"""
When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type
may be wrong. This method gets the corrected data type for Pandas if that type may be
inferred incorrectly.
"""
import numpy as np
if type(dt) == ByteType:
return np.int8
elif type(dt) == ShortType:
return np.int16
elif type(dt) == IntegerType:
return np.int32
elif type(dt) == LongType:
return np.int64
elif type(dt) == FloatType:
return np.float32
elif type(dt) == DoubleType:
return np.float64
elif type(dt) == BooleanType:
return np.bool
elif type(dt) == TimestampType:
return np.datetime64
else:
return None
def _collect_as_arrow(self):
"""
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
and available on driver and worker Python environments.
.. note:: Experimental.
"""
from pyspark.sql.dataframe import DataFrame
assert isinstance(self, DataFrame)
with SCCallSiteSync(self._sc):
port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
# Collect list of un-ordered batches where last element is a list of correct order indices
try:
results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
finally:
# Join serving thread and raise any exceptions from collectAsArrowToPython
jsocket_auth_server.getResult()
# Separate RecordBatches from batch order indices in results
batches = results[:-1]
batch_order = results[-1]
# Re-order the batch list using the correct order
return [batches[i] for i in batch_order]
class SparkConversionMixin(object):
"""
Min-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession`
can use this class.
"""
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
from pyspark.sql import SparkSession
assert isinstance(self, SparkSession)
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
if self._wrapped._conf.pandasRespectSessionTimeZone():
timezone = self._wrapped._conf.sessionLocalTimeZone()
else:
timezone = None
# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) if not isinstance(x, basestring) else
(x.encode('utf-8') if not isinstance(x, str) else x)
for x in data.columns]
if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
from pyspark.util import _exception_message
if self._wrapped._conf.arrowPySparkFallbackEnabled():
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
"Attempting non-optimization as "
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
"true." % _exception_message(e))
warnings.warn(msg)
else:
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
"reached the error below and will not continue because automatic "
"fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
"has been set to false.\n %s" % _exception_message(e))
warnings.warn(msg)
raise
data = self._convert_from_pandas(data, schema, timezone)
return self._create_dataframe(data, schema, samplingRatio, samplingRatio)
def _convert_from_pandas(self, pdf, schema, timezone):
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
:return list of records
"""
from pyspark.sql import SparkSession
assert isinstance(self, SparkSession)
if timezone is not None:
from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
copied = False
if isinstance(schema, StructType):
for field in schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
if s is not pdf[field.name]:
if not copied:
# Copy once if the series is modified to prevent the original
# Pandas DataFrame from being updated
pdf = pdf.copy()
copied = True
pdf[field.name] = s
else:
for column, series in pdf.iteritems():
s = _check_series_convert_timestamps_tz_local(series, timezone)
if s is not series:
if not copied:
# Copy once if the series is modified to prevent the original
# Pandas DataFrame from being updated
pdf = pdf.copy()
copied = True
pdf[column] = s
# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)
# Check if any columns need to be fixed for Spark to infer properly
if len(np_records) > 0:
record_dtype = self._get_numpy_record_dtype(np_records[0])
if record_dtype is not None:
return [r.astype(record_dtype).tolist() for r in np_records]
# Convert list of numpy records to python lists
return [r.tolist() for r in np_records]
def _get_numpy_record_dtype(self, rec):
"""
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
the dtypes of fields in a record so they can be properly loaded into Spark.
:param rec: a numpy record to check field dtypes
:return corrected dtype for a numpy.record or None if no correction needed
"""
import numpy as np
cur_dtypes = rec.dtype
col_names = cur_dtypes.names
record_type_list = []
has_rec_fix = False
for i in xrange(len(cur_dtypes)):
curr_type = cur_dtypes[i]
# If type is a datetime64 timestamp, convert to microseconds
# NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
# conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
if curr_type == np.dtype('datetime64[ns]'):
curr_type = 'datetime64[us]'
has_rec_fix = True
record_type_list.append((str(col_names[i]), curr_type))
return np.dtype(record_type_list) if has_rec_fix else None
def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
assert isinstance(self, SparkSession)
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
from pyspark.sql.types import TimestampType
from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version
require_minimum_pandas_version()
require_minimum_pyarrow_version()
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
import pyarrow as pa
# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
schema = struct
# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
elif isinstance(schema, DataType):
raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
else:
# Any timestamps must be coerced to be compatible with Spark
arrow_types = [to_arrow_type(TimestampType())
if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
for t in pdf.dtypes]
# Slice the DataFrame to be batched
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
# Create list of Arrow (columns, type) for serializer dump_stream
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
for pdf_slice in pdf_slices]
jsqlContext = self._wrapped._jsqlContext
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
def create_RDD_server():
return self._jvm.ArrowRDDServer(jsqlContext)
# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
return df
def _test():
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.conversion
globs = pyspark.sql.pandas.conversion.__dict__.copy()
spark = SparkSession.builder\
.master("local[4]")\
.appName("sql.pandas.conversion tests")\
.getOrCreate()
globs['spark'] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.pandas.conversion, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
spark.stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()

View file

@ -0,0 +1,539 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import sys
from pyspark import since
from pyspark.rdd import PythonEvalType
from pyspark.sql.types import DataType
from pyspark.sql.udf import _create_udf
class PandasUDFType(object):
"""Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`.
"""
SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF
SCALAR_ITER = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
@since(2.3)
def pandas_udf(f=None, returnType=None, functionType=None):
"""
Creates a vectorized user defined function (UDF).
:param f: user-defined function. A python function if used as a standalone function
:param returnType: the return type of the user-defined function. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
:param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
Default: SCALAR.
The function type of the UDF can be one of the following:
1. SCALAR
A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
:class:`MapType`, nested :class:`StructType` are currently not supported as output types.
Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql.types import IntegerType, StringType
>>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP
>>> @pandas_udf(StringType()) # doctest: +SKIP
... def to_upper(s):
... return s.str.upper()
...
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> df = spark.createDataFrame([(1, "John Doe", 21)],
... ("id", "name", "age")) # doctest: +SKIP
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
... .show() # doctest: +SKIP
+----------+--------------+------------+
|slen(name)|to_upper(name)|add_one(age)|
+----------+--------------+------------+
| 8| JOHN DOE| 22|
+----------+--------------+------------+
>>> @pandas_udf("first string, last string") # doctest: +SKIP
... def split_expand(n):
... return n.str.split(expand=True)
>>> df.select(split_expand("name")).show() # doctest: +SKIP
+------------------+
|split_expand(name)|
+------------------+
| [John, Doe]|
+------------------+
.. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
column, but is the length of an internal batch used for each call to the function.
Therefore, this can be used, for example, to ensure the length of each returned
`pandas.Series`, and can not be used as the column length.
2. SCALAR_ITER
A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the
wrapped Python function takes an iterator of batches as input instead of a single batch and,
instead of returning a single output batch, it yields output batches or explicitly returns an
generator or an iterator of output batches.
It is useful when the UDF execution requires initializing some state, e.g., loading a machine
learning model file to apply inference to every input batch.
.. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all
batches from one partition, although it is currently implemented this way.
Your code shall not rely on this behavior because it might change in the future for
further optimization, e.g., one invocation processes multiple partitions.
Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
>>> import pandas as pd # doctest: +SKIP
>>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType
>>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP
>>> df = spark.createDataFrame(pdf) # doctest: +SKIP
When the UDF is called with a single column that is not `StructType`, the input to the
underlying function is an iterator of `pd.Series`.
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def plus_one(batch_iter):
... for x in batch_iter:
... yield x + 1
...
>>> df.select(plus_one(col("x"))).show() # doctest: +SKIP
+-----------+
|plus_one(x)|
+-----------+
| 2|
| 3|
| 4|
+-----------+
When the UDF is called with more than one columns, the input to the underlying function is an
iterator of `pd.Series` tuple.
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def multiply_two_cols(batch_iter):
... for a, b in batch_iter:
... yield a * b
...
>>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP
+-----------------------+
|multiply_two_cols(x, x)|
+-----------------------+
| 1|
| 4|
| 9|
+-----------------------+
When the UDF is called with a single column that is `StructType`, the input to the underlying
function is an iterator of `pd.DataFrame`.
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def multiply_two_nested_cols(pdf_iter):
... for pdf in pdf_iter:
... yield pdf["a"] * pdf["b"]
...
>>> df.select(
... multiply_two_nested_cols(
... struct(col("x").alias("a"), col("x").alias("b"))
... ).alias("y")
... ).show() # doctest: +SKIP
+---+
| y|
+---+
| 1|
| 4|
| 9|
+---+
In the UDF, you can initialize some states before processing batches, wrap your code with
`try ... finally ...` or use context managers to ensure the release of resources at the end
or in case of early termination.
>>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
... def plus_y(batch_iter):
... y = y_bc.value # initialize some state
... try:
... for x in batch_iter:
... yield x + y
... finally:
... pass # release resources here, if any
...
>>> df.select(plus_y(col("x"))).show() # doctest: +SKIP
+---------+
|plus_y(x)|
+---------+
| 2|
| 3|
| 4|
+---------+
3. GROUPED_MAP
A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
The returnType should be a :class:`StructType` describing the schema of the returned
`pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match
the field names in the defined returnType schema if specified as strings, or match the
field data types by position if not strings, e.g. integer indices.
The length of the returned `pandas.DataFrame` can be arbitrary.
Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
Alternatively, the user can define a function that takes two arguments.
In this case, the grouping key(s) will be passed as the first argument and the data will
be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
This is useful when the user does not want to hardcode grouping key(s) in the function.
>>> import pandas as pd # doctest: +SKIP
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def mean_udf(key, pdf):
... # key is a tuple of one numpy.int64, which is the value
... # of 'id' for the current group
... return pd.DataFrame([key + (pdf.v.mean(),)])
>>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
+---+---+
| id| v|
+---+---+
| 1|1.5|
| 2|6.0|
+---+---+
>>> @pandas_udf(
... "id long, `ceil(v / 2)` long, v double",
... PandasUDFType.GROUPED_MAP) # doctest: +SKIP
>>> def sum_udf(key, pdf):
... # key is a tuple of two numpy.int64s, which is the values
... # of 'id' and 'ceil(df.v / 2)' for the current group
... return pd.DataFrame([key + (pdf.v.sum(),)])
>>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP
+---+-----------+----+
| id|ceil(v / 2)| v|
+---+-----------+----+
| 2| 5|10.0|
| 1| 1| 3.0|
| 2| 3| 5.0|
| 2| 2| 3.0|
+---+-----------+----+
.. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
recommended to explicitly index the columns by name to ensure the positions are correct,
or alternatively use an `OrderedDict`.
For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
`pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
4. GROUPED_AGG
A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
The returned scalar can be either a python primitive type, e.g., `int` or `float`
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
:class:`MapType` and :class:`StructType` are currently not supported as output types.
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
:class:`pyspark.sql.Window`
This example shows using grouped aggregated UDFs with groupby:
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
>>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP
+---+-----------+
| id|mean_udf(v)|
+---+-----------+
| 1| 1.5|
| 2| 6.0|
+---+-----------+
This example shows using grouped aggregated UDFs as window functions.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql import Window
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
>>> w = (Window.partitionBy('id')
... .orderBy('v')
... .rowsBetween(-1, 0))
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
+---+----+------+
| id| v|mean_v|
+---+----+------+
| 1| 1.0| 1.0|
| 1| 2.0| 1.5|
| 2| 3.0| 3.0|
| 2| 5.0| 4.0|
| 2|10.0| 7.5|
+---+----+------+
.. note:: For performance reasons, the input series to window functions are not copied.
Therefore, mutating the input series is not allowed and will cause incorrect results.
For the same reason, users should also not rely on the index of the input series.
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
5. MAP_ITER
A map iterator Pandas UDFs are used to transform data with an iterator of batches.
It can be used with :meth:`pyspark.sql.DataFrame.mapInPandas`.
It can return the output of arbitrary length in contrast to the scalar Pandas UDF.
It maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
function and returns the result as a :class:`DataFrame`.
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return another
iterator of `pandas.DataFrame`\\s. All columns are passed together as an
iterator of `pandas.DataFrame`\\s to the user-defined function and the returned iterator of
`pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
... ("id", "age")) # doctest: +SKIP
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
... def filter_func(batch_iter):
... for pdf in batch_iter:
... yield pdf[pdf.id == 1]
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
6. COGROUPED_MAP
A cogrouped map UDF defines transformation: (`pandas.DataFrame`, `pandas.DataFrame`) ->
`pandas.DataFrame`. The `returnType` should be a :class:`StructType` describing the schema
of the returned `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame`
must either match the field names in the defined `returnType` schema if specified as strings,
or match the field data types by position if not strings, e.g. integer indices. The length
of the returned `pandas.DataFrame` can be arbitrary.
CoGrouped map UDFs are used with :meth:`pyspark.sql.CoGroupedData.apply`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df1 = spark.createDataFrame(
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
... ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
... [(20000101, 1, "x"), (20000101, 2, "y")],
... ("time", "id", "v2"))
>>> @pandas_udf("time int, id int, v1 double, v2 string",
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
... def asof_join(l, r):
... return pd.merge_asof(l, r, on="time", by="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
+---------+---+---+---+
| time| id| v1| v2|
+---------+---+---+---+
| 20000101| 1|1.0| x|
| 20000102| 1|3.0| x|
| 20000101| 2|2.0| y|
| 20000102| 2|4.0| y|
+---------+---+---+---+
Alternatively, the user can define a function that takes three arguments. In this case,
the grouping key(s) will be passed as the first argument and the data will be passed as the
second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
`pandas.DataFrame` containing all columns from the original Spark DataFrames.
>>> @pandas_udf("time int, id int, v1 double, v2 string",
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
... def asof_join(k, l, r):
... if k == (1,):
... return pd.merge_asof(l, r, on="time", by="id")
... else:
... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
+---------+---+---+---+
| time| id| v1| v2|
+---------+---+---+---+
| 20000101| 1|1.0| x|
| 20000102| 1|3.0| x|
+---------+---+---+---+
.. note:: The user-defined functions are considered deterministic by default. Due to
optimization, duplicate invocations may be eliminated or the function may even be invoked
more times than it is present in the query. If your function is not deterministic, call
`asNondeterministic` on the user defined function. E.g.:
>>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP
... def random(v):
... import numpy as np
... import pandas as pd
... return pd.Series(np.random.randn(len(v))
>>> random = random.asNondeterministic() # doctest: +SKIP
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
.. note:: The user-defined functions do not take keyword arguments on the calling side.
.. note:: The data type of returned `pandas.Series` from the user-defined functions should be
matched with defined returnType (see :meth:`types.to_arrow_type` and
:meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do
conversion on returned data. The conversion is not guaranteed to be correct and results
should be checked for accuracy by users.
"""
# The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that
# are not yet visible to the user. Some of behaviors are buggy and might be changed in the near
# future. The table might have to be eventually documented externally.
# Please see SPARK-28132's PR to see the codes in order to generate the table below.
#
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
# |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
# | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| X| X| # noqa
# | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| 0| X| # noqa
# | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
# | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
# | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| X| X| # noqa
# | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
# | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
# | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| X| X| # noqa
# | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| X| X| # noqa
# | string| None| ''| ''| ''| '\x01'| '\x01'| ''| ''| '\x01'| '\x01'| ''| ''| ''| X| X| 'a'| X| X| ''| X| ''| X| X| # noqa
# | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa
# | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa
# | map<string,int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
# | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
# | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')| X| bytearray(b'')| # noqa
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
#
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
# used in `returnType`.
# Note: The values inside of the table are generated by `repr`.
# Note: Python 3.7.3, Pandas 0.24.2 and PyArrow 0.13.0 are used.
# Note: Timezone is KST.
# Note: 'X' means it throws an exception during the conversion.
# decorator @pandas_udf(returnType, functionType)
is_decorator = f is None or isinstance(f, (str, DataType))
if is_decorator:
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
if functionType is not None:
# @pandas_udf(dataType, functionType=functionType)
# @pandas_udf(returnType=dataType, functionType=functionType)
eval_type = functionType
elif returnType is not None and isinstance(returnType, int):
# @pandas_udf(dataType, functionType)
eval_type = returnType
else:
# @pandas_udf(dataType) or @pandas_udf(returnType=dataType)
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
else:
return_type = returnType
if functionType is not None:
eval_type = functionType
else:
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
if return_type is None:
raise ValueError("Invalid returnType: returnType can not be None")
if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")
if is_decorator:
return functools.partial(_create_udf, returnType=return_type, evalType=eval_type)
else:
return _create_udf(f=f, returnType=return_type, evalType=eval_type)
def _test():
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.functions
globs = pyspark.sql.pandas.functions.__dict__.copy()
spark = SparkSession.builder\
.master("local[4]")\
.appName("sql.pandas.functions tests")\
.getOrCreate()
globs['spark'] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.pandas.functions, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
spark.stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()

View file

@ -22,7 +22,83 @@ from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
class CoGroupedData(object):
class PandasGroupedOpsMixin(object):
"""
Min-in for pandas grouped operations. Currently, only :class:`GroupedData`
can use this class.
"""
def apply(self, udf):
"""
Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
as a `DataFrame`.
The user-defined function should take a `pandas.DataFrame` and return another
`pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
to the user-function and the returned `pandas.DataFrame` are combined as a
:class:`DataFrame`.
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.
.. note:: This function requires a full shuffle. All the data of a group will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
:param udf: a grouped map user-defined function returned by
:func:`pyspark.sql.functions.pandas_udf`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
"""
from pyspark.sql import GroupedData
assert isinstance(self, GroupedData)
# Columns are special because hasattr always return True
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
"GROUPED_MAP.")
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)
@since(3.0)
def cogroup(self, other):
"""
Cogroups this group with another group so that we can run cogrouped operations.
See :class:`CoGroupedData` for the operations that can be run.
"""
from pyspark.sql import GroupedData
assert isinstance(self, GroupedData)
return PandasCogroupedOps(self, other)
class PandasCogroupedOps(object):
"""
A logical grouping of two :class:`GroupedData`,
created by :func:`GroupedData.cogroup`.
@ -124,15 +200,15 @@ class CoGroupedData(object):
def _test():
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.cogroup
globs = pyspark.sql.cogroup.__dict__.copy()
import pyspark.sql.pandas.group_ops
globs = pyspark.sql.pandas.group_ops.__dict__.copy()
spark = SparkSession.builder\
.master("local[4]")\
.appName("sql.cogroup tests")\
.appName("sql.pandas.group tests")\
.getOrCreate()
globs['spark'] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.cogroup, globs=globs,
pyspark.sql.pandas.group_ops, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
spark.stop()
if failure_count:

View file

@ -0,0 +1,96 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from pyspark import since
from pyspark.rdd import PythonEvalType
class PandasMapOpsMixin(object):
"""
Min-in for pandas map operations. Currently, only :class:`DataFrame`
can use this class.
"""
@since(3.0)
def mapInPandas(self, udf):
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
function and returns the result as a :class:`DataFrame`.
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return
another iterator of `pandas.DataFrame`\\s. All columns are passed
together as an iterator of `pandas.DataFrame`\\s to the user-defined function and the
returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
Each `pandas.DataFrame` size can be controlled by
`spark.sql.execution.arrow.maxRecordsPerBatch`.
Its schema must match the returnType of the Pandas user-defined function.
:param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
... ("id", "age")) # doctest: +SKIP
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
... def filter_func(batch_iter):
... for pdf in batch_iter:
... yield pdf[pdf.id == 1]
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
"""
from pyspark.sql import Column, DataFrame
assert isinstance(self, DataFrame)
# Columns are special because hasattr always return True
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.evalType != PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
"MAP_ITER.")
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)
def _test():
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.map_ops
globs = pyspark.sql.pandas.map_ops.__dict__.copy()
spark = SparkSession.builder\
.master("local[4]")\
.appName("sql.pandas.map_ops tests")\
.getOrCreate()
globs['spark'] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.pandas.map_ops, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
spark.stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()

View file

@ -0,0 +1,281 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""
import sys
if sys.version < '3':
from itertools import izip as zip
else:
basestring = unicode = str
xrange = range
from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer
class SpecialLengths(object):
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
END_OF_STREAM = -4
NULL = -5
START_ARROW_STREAM = -6
class ArrowCollectSerializer(Serializer):
"""
Deserialize a stream of batches followed by batch order information. Used in
PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython()
in the JVM.
"""
def __init__(self):
self.serializer = ArrowStreamSerializer()
def dump_stream(self, iterator, stream):
return self.serializer.dump_stream(iterator, stream)
def load_stream(self, stream):
"""
Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
a list of indices that can be used to put the RecordBatches in the correct order.
"""
# load the batches
for batch in self.serializer.load_stream(stream):
yield batch
# load the batch order indices or propagate any error that occurred in the JVM
num = read_int(stream)
if num == -1:
error_msg = UTF8Deserializer().loads(stream)
raise RuntimeError("An error occurred while calling "
"ArrowCollectSerializer.load_stream: {}".format(error_msg))
batch_order = []
for i in xrange(num):
index = read_int(stream)
batch_order.append(index)
yield batch_order
def __repr__(self):
return "ArrowCollectSerializer(%s)" % self.serializer
class ArrowStreamSerializer(Serializer):
"""
Serializes Arrow record batches as a stream.
"""
def dump_stream(self, iterator, stream):
import pyarrow as pa
writer = None
try:
for batch in iterator:
if writer is None:
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
finally:
if writer is not None:
writer.close()
def load_stream(self, stream):
import pyarrow as pa
reader = pa.ipc.open_stream(stream)
for batch in reader:
yield batch
def __repr__(self):
return "ArrowStreamSerializer"
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
Serializes Pandas.Series as Arrow data with Arrow streaming format.
:param timezone: A timezone to respect when handling timestamp values
:param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
:param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
"""
def __init__(self, timezone, safecheck, assign_cols_by_name):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name
def arrow_to_pandas(self, arrow_column):
from pyspark.sql.pandas.types import _check_series_localize_timestamps
# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
s = arrow_column.to_pandas(date_as_object=True)
s = _check_series_localize_timestamps(s, self._timezone)
return s
def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series,
with optional type.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
"conversions warned by Arrow. Arrow safe type check can be " + \
"disabled by using SQL config " + \
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
raise RuntimeError(error_msg % (s.dtype, t), e)
return array
arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s)))
# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, basestring)
for name in s.columns):
arrs_names = [(create_array(s[field.name], field.type), field.name)
for field in t]
# Assign result columns by position
else:
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)]
struct_arrs, struct_names = zip(*arrs_names)
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
batches = (self._create_batch(series) for series in iterator)
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
import pyarrow as pa
for batch in batches:
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
def __repr__(self):
return "ArrowStreamPandasSerializer"
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""
Serializer used by Python worker to evaluate Pandas UDFs
"""
def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
super(ArrowStreamPandasUDFSerializer, self) \
.__init__(timezone, safecheck, assign_cols_by_name)
self._df_for_struct = df_for_struct
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types
if self._df_for_struct and types.is_struct(arrow_column.type):
import pandas as pd
series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
.rename(field.name)
for column, field in zip(arrow_column.flatten(), arrow_column.type)]
s = pd.concat(series, axis=1)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
return s
def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""
def init_stream_yield_batches():
should_write_start_length = True
for series in iterator:
batch = self._create_batch(series)
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch
return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
def __repr__(self):
return "ArrowStreamPandasUDFSerializer"
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
def load_stream(self, stream):
"""
Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
lists of pandas.Series.
"""
import pyarrow as pa
dataframes_in_group = None
while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)
if dataframes_in_group == 2:
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
)
elif dataframes_in_group != 0:
raise ValueError(
'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))

View file

@ -0,0 +1,284 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Type-specific codes between pandas and PyArrow. Also contains some utils to correct
pandas instances during the type conversion.
"""
from pyspark.sql.types import *
def to_arrow_type(dt):
""" Convert Spark data type to pyarrow type
"""
import pyarrow as pa
if type(dt) == BooleanType:
arrow_type = pa.bool_()
elif type(dt) == ByteType:
arrow_type = pa.int8()
elif type(dt) == ShortType:
arrow_type = pa.int16()
elif type(dt) == IntegerType:
arrow_type = pa.int32()
elif type(dt) == LongType:
arrow_type = pa.int64()
elif type(dt) == FloatType:
arrow_type = pa.float32()
elif type(dt) == DoubleType:
arrow_type = pa.float64()
elif type(dt) == DecimalType:
arrow_type = pa.decimal128(dt.precision, dt.scale)
elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == BinaryType:
arrow_type = pa.binary()
elif type(dt) == DateType:
arrow_type = pa.date32()
elif type(dt) == TimestampType:
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
elif type(dt) == ArrayType:
if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
elif type(dt) == StructType:
if any(type(field.dataType) == StructType for field in dt):
raise TypeError("Nested StructType not supported in conversion to Arrow")
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
for field in dt]
arrow_type = pa.struct(fields)
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type
def to_arrow_schema(schema):
""" Convert a schema from Spark to Arrow
"""
import pyarrow as pa
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
for field in schema]
return pa.schema(fields)
def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
import pyarrow.types as types
if types.is_boolean(at):
spark_type = BooleanType()
elif types.is_int8(at):
spark_type = ByteType()
elif types.is_int16(at):
spark_type = ShortType()
elif types.is_int32(at):
spark_type = IntegerType()
elif types.is_int64(at):
spark_type = LongType()
elif types.is_float32(at):
spark_type = FloatType()
elif types.is_float64(at):
spark_type = DoubleType()
elif types.is_decimal(at):
spark_type = DecimalType(precision=at.precision, scale=at.scale)
elif types.is_string(at):
spark_type = StringType()
elif types.is_binary(at):
spark_type = BinaryType()
elif types.is_date32(at):
spark_type = DateType()
elif types.is_timestamp(at):
spark_type = TimestampType()
elif types.is_list(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
elif types.is_struct(at):
if any(types.is_struct(field.type) for field in at):
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in at])
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type
def from_arrow_schema(arrow_schema):
""" Convert schema from Arrow to Spark.
"""
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in arrow_schema])
def _get_local_timezone():
""" Get local timezone using pytz with environment variable, or dateutil.
If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone
string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and
it reads system configuration to know the system local timezone.
See also:
- https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753
- https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338
"""
import os
return os.environ.get('TZ', 'dateutil/:')
def _check_series_localize_timestamps(s, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
If the input series is not a timestamp series, then the same series is returned. If the input
series is a timestamp series, then a converted series is returned.
:param s: pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series that have been converted to tz-naive
"""
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64tz_dtype
tz = timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(tz).dt.tz_localize(None)
else:
return s
def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
:param pdf: pandas.DataFrame
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
for column, series in pdf.iteritems():
pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf
def _check_series_convert_timestamps_internal(s, timezone):
"""
Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
Spark internal storage
:param s: a pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
# When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive
# timestamp is during the hour when the clock is adjusted backward during due to
# daylight saving time (dst).
# E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to
# 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize
# a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either
# dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500).
#
# Here we explicit choose to use standard time. This matches the default behavior of
# pytz.
#
# Here are some code to help understand this behavior:
# >>> import datetime
# >>> import pandas as pd
# >>> import pytz
# >>>
# >>> t = datetime.datetime(2015, 11, 1, 1, 30)
# >>> ts = pd.Series([t])
# >>> tz = pytz.timezone('America/New_York')
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=True)
# 0 2015-11-01 01:30:00-04:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=False)
# 0 2015-11-01 01:30:00-05:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> str(tz.localize(t))
# '2015-11-01 01:30:00-05:00'
tz = timezone or _get_local_timezone()
return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
else:
return s
def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param from_timezone: the timezone to convert from. if None then use local timezone
:param to_timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
from_tz = from_timezone or _get_local_timezone()
to_tz = to_timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(to_tz).dt.tz_localize(None)
elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
return s.apply(
lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
if ts is not pd.NaT else pd.NaT)
else:
return s
def _check_series_convert_timestamps_local_tz(s, timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, None, timezone)
def _check_series_convert_timestamps_tz_local(s, timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param timezone: the timezone to convert from. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, timezone, None)

View file

@ -0,0 +1,60 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def require_minimum_pandas_version():
""" Raise ImportError if minimum version of Pandas is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
minimum_pandas_version = "0.23.2"
from distutils.version import LooseVersion
try:
import pandas
have_pandas = True
except ImportError:
have_pandas = False
if not have_pandas:
raise ImportError("Pandas >= %s must be installed; however, "
"it was not found." % minimum_pandas_version)
if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
raise ImportError("Pandas >= %s must be installed; however, "
"your version was %s." % (minimum_pandas_version, pandas.__version__))
def require_minimum_pyarrow_version():
""" Raise ImportError if minimum version of pyarrow is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
minimum_pyarrow_version = "0.15.1"
from distutils.version import LooseVersion
import os
try:
import pyarrow
have_arrow = True
except ImportError:
have_arrow = False
if not have_arrow:
raise ImportError("PyArrow >= %s must be installed; however, "
"it was not found." % minimum_pyarrow_version)
if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
raise ImportError("PyArrow >= %s must be installed; however, "
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
if os.environ.get("ARROW_PRE_0_15_IPC_FORMAT", "0") == "1":
raise RuntimeError("Arrow legacy IPC format is not supported in PySpark, "
"please unset ARROW_PRE_0_15_IPC_FORMAT")

View file

@ -15,6 +15,8 @@
# limitations under the License.
#
# To disallow implicit relative import. Remove this once we drop Python 2.
from __future__ import absolute_import
from __future__ import print_function
import sys
import warnings
@ -25,15 +27,16 @@ if sys.version >= '3':
basestring = unicode = str
xrange = range
else:
from itertools import izip as zip, imap as map
from itertools import imap as map
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.conversion import SparkConversionMixin
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \
from pyspark.sql.types import Row, DataType, StringType, StructType, \
_make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
_parse_datatype_string
from pyspark.sql.utils import install_exception_handler
@ -60,7 +63,7 @@ def _monkey_patch_RDD(sparkSession):
RDD.toDF = toDF
class SparkSession(object):
class SparkSession(SparkConversionMixin):
"""The entry point to programming Spark with the Dataset and DataFrame API.
A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as
@ -458,135 +461,6 @@ class SparkSession(object):
data = [schema.toInternal(row) for row in data]
return self._sc.parallelize(data), schema
def _get_numpy_record_dtype(self, rec):
"""
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
the dtypes of fields in a record so they can be properly loaded into Spark.
:param rec: a numpy record to check field dtypes
:return corrected dtype for a numpy.record or None if no correction needed
"""
import numpy as np
cur_dtypes = rec.dtype
col_names = cur_dtypes.names
record_type_list = []
has_rec_fix = False
for i in xrange(len(cur_dtypes)):
curr_type = cur_dtypes[i]
# If type is a datetime64 timestamp, convert to microseconds
# NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
# conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
if curr_type == np.dtype('datetime64[ns]'):
curr_type = 'datetime64[us]'
has_rec_fix = True
record_type_list.append((str(col_names[i]), curr_type))
return np.dtype(record_type_list) if has_rec_fix else None
def _convert_from_pandas(self, pdf, schema, timezone):
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
:return list of records
"""
if timezone is not None:
from pyspark.sql.types import _check_series_convert_timestamps_tz_local
copied = False
if isinstance(schema, StructType):
for field in schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
if s is not pdf[field.name]:
if not copied:
# Copy once if the series is modified to prevent the original
# Pandas DataFrame from being updated
pdf = pdf.copy()
copied = True
pdf[field.name] = s
else:
for column, series in pdf.iteritems():
s = _check_series_convert_timestamps_tz_local(series, timezone)
if s is not series:
if not copied:
# Copy once if the series is modified to prevent the original
# Pandas DataFrame from being updated
pdf = pdf.copy()
copied = True
pdf[column] = s
# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)
# Check if any columns need to be fixed for Spark to infer properly
if len(np_records) > 0:
record_dtype = self._get_numpy_record_dtype(np_records[0])
if record_dtype is not None:
return [r.astype(record_dtype).tolist() for r in np_records]
# Convert list of numpy records to python lists
return [r.tolist() for r in np_records]
def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowStreamPandasSerializer
from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
from pyspark.sql.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version
require_minimum_pandas_version()
require_minimum_pyarrow_version()
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
import pyarrow as pa
# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
schema = struct
# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
elif isinstance(schema, DataType):
raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
else:
# Any timestamps must be coerced to be compatible with Spark
arrow_types = [to_arrow_type(TimestampType())
if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
for t in pdf.dtypes]
# Slice the DataFrame to be batched
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
# Create list of Arrow (columns, type) for serializer dump_stream
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
for pdf_slice in pdf_slices]
jsqlContext = self._wrapped._jsqlContext
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
def create_RDD_server():
return self._jvm.ArrowRDDServer(jsqlContext)
# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
return df
@staticmethod
def _create_shell_session():
"""
@ -722,46 +596,12 @@ class SparkSession(object):
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
if self._wrapped._conf.pandasRespectSessionTimeZone():
timezone = self._wrapped._conf.sessionLocalTimeZone()
else:
timezone = None
# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) if not isinstance(x, basestring) else
(x.encode('utf-8') if not isinstance(x, str) else x)
for x in data.columns]
if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
from pyspark.util import _exception_message
if self._wrapped._conf.arrowPySparkFallbackEnabled():
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
"failed by the reason below:\n %s\n"
"Attempting non-optimization as "
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
"true." % _exception_message(e))
warnings.warn(msg)
else:
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
"reached the error below and will not continue because automatic "
"fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
"has been set to false.\n %s" % _exception_message(e))
warnings.warn(msg)
raise
data = self._convert_from_pandas(data, schema, timezone)
# Create a DataFrame from pandas DataFrame.
return super(SparkSession, self).createDataFrame(
data, schema, verifySchema, samplingRatio)
return self._create_dataframe(data, schema, verifySchema, samplingRatio)
def _create_dataframe(self, data, schema, verifySchema, samplingRatio):
if isinstance(schema, StructType):
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True

View file

@ -178,7 +178,7 @@ class ArrowTests(ReusedSQLTestCase):
self.assertFalse(pdf_ny.equals(pdf_la))
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
pdf_la_corrected = pdf_la.copy()
for field in self.schema:
if isinstance(field.dataType, TimestampType):
@ -311,7 +311,7 @@ class ArrowTests(ReusedSQLTestCase):
self.assertTrue(pdf.equals(pdf_copy))
def test_schema_conversion_roundtrip(self):
from pyspark.sql.types import from_arrow_schema, to_arrow_schema
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
arrow_schema = to_arrow_schema(self.schema)
schema_rt = from_arrow_schema(arrow_schema)
self.assertEquals(self.schema, schema_rt)

View file

@ -1608,267 +1608,6 @@ register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())
def to_arrow_type(dt):
""" Convert Spark data type to pyarrow type
"""
import pyarrow as pa
if type(dt) == BooleanType:
arrow_type = pa.bool_()
elif type(dt) == ByteType:
arrow_type = pa.int8()
elif type(dt) == ShortType:
arrow_type = pa.int16()
elif type(dt) == IntegerType:
arrow_type = pa.int32()
elif type(dt) == LongType:
arrow_type = pa.int64()
elif type(dt) == FloatType:
arrow_type = pa.float32()
elif type(dt) == DoubleType:
arrow_type = pa.float64()
elif type(dt) == DecimalType:
arrow_type = pa.decimal128(dt.precision, dt.scale)
elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == BinaryType:
arrow_type = pa.binary()
elif type(dt) == DateType:
arrow_type = pa.date32()
elif type(dt) == TimestampType:
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
elif type(dt) == ArrayType:
if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
elif type(dt) == StructType:
if any(type(field.dataType) == StructType for field in dt):
raise TypeError("Nested StructType not supported in conversion to Arrow")
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
for field in dt]
arrow_type = pa.struct(fields)
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type
def to_arrow_schema(schema):
""" Convert a schema from Spark to Arrow
"""
import pyarrow as pa
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
for field in schema]
return pa.schema(fields)
def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
import pyarrow.types as types
if types.is_boolean(at):
spark_type = BooleanType()
elif types.is_int8(at):
spark_type = ByteType()
elif types.is_int16(at):
spark_type = ShortType()
elif types.is_int32(at):
spark_type = IntegerType()
elif types.is_int64(at):
spark_type = LongType()
elif types.is_float32(at):
spark_type = FloatType()
elif types.is_float64(at):
spark_type = DoubleType()
elif types.is_decimal(at):
spark_type = DecimalType(precision=at.precision, scale=at.scale)
elif types.is_string(at):
spark_type = StringType()
elif types.is_binary(at):
spark_type = BinaryType()
elif types.is_date32(at):
spark_type = DateType()
elif types.is_timestamp(at):
spark_type = TimestampType()
elif types.is_list(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
elif types.is_struct(at):
if any(types.is_struct(field.type) for field in at):
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in at])
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type
def from_arrow_schema(arrow_schema):
""" Convert schema from Arrow to Spark.
"""
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in arrow_schema])
def _get_local_timezone():
""" Get local timezone using pytz with environment variable, or dateutil.
If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone
string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and
it reads system configuration to know the system local timezone.
See also:
- https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753
- https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338
"""
import os
return os.environ.get('TZ', 'dateutil/:')
def _check_series_localize_timestamps(s, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
If the input series is not a timestamp series, then the same series is returned. If the input
series is a timestamp series, then a converted series is returned.
:param s: pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series that have been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64tz_dtype
tz = timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(tz).dt.tz_localize(None)
else:
return s
def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
:param pdf: pandas.DataFrame
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
for column, series in pdf.iteritems():
pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf
def _check_series_convert_timestamps_internal(s, timezone):
"""
Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
Spark internal storage
:param s: a pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
# When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive
# timestamp is during the hour when the clock is adjusted backward during due to
# daylight saving time (dst).
# E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to
# 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize
# a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either
# dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500).
#
# Here we explicit choose to use standard time. This matches the default behavior of
# pytz.
#
# Here are some code to help understand this behavior:
# >>> import datetime
# >>> import pandas as pd
# >>> import pytz
# >>>
# >>> t = datetime.datetime(2015, 11, 1, 1, 30)
# >>> ts = pd.Series([t])
# >>> tz = pytz.timezone('America/New_York')
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=True)
# 0 2015-11-01 01:30:00-04:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=False)
# 0 2015-11-01 01:30:00-05:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> str(tz.localize(t))
# '2015-11-01 01:30:00-05:00'
tz = timezone or _get_local_timezone()
return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
else:
return s
def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param from_timezone: the timezone to convert from. if None then use local timezone
:param to_timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
from_tz = from_timezone or _get_local_timezone()
to_tz = to_timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(to_tz).dt.tz_localize(None)
elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
return s.apply(
lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
if ts is not pd.NaT else pd.NaT)
else:
return s
def _check_series_convert_timestamps_local_tz(s, timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, None, timezone)
def _check_series_convert_timestamps_tz_local(s, timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param timezone: the timezone to convert from. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, timezone, None)
def _test():
import doctest
from pyspark.context import SparkContext

View file

@ -23,8 +23,8 @@ import sys
from pyspark import SparkContext, since
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
to_arrow_type, to_arrow_schema
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.util import _get_argspec
__all__ = ["UDFRegistration"]
@ -46,7 +46,7 @@ def _create_udf(f, returnType, evalType):
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
from pyspark.sql.utils import require_minimum_pyarrow_version
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
argspec = _get_argspec(f)

View file

@ -136,50 +136,6 @@ def toJArray(gateway, jtype, arr):
return jarr
def require_minimum_pandas_version():
""" Raise ImportError if minimum version of Pandas is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
minimum_pandas_version = "0.23.2"
from distutils.version import LooseVersion
try:
import pandas
have_pandas = True
except ImportError:
have_pandas = False
if not have_pandas:
raise ImportError("Pandas >= %s must be installed; however, "
"it was not found." % minimum_pandas_version)
if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
raise ImportError("Pandas >= %s must be installed; however, "
"your version was %s." % (minimum_pandas_version, pandas.__version__))
def require_minimum_pyarrow_version():
""" Raise ImportError if minimum version of pyarrow is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
minimum_pyarrow_version = "0.15.1"
from distutils.version import LooseVersion
import os
try:
import pyarrow
have_arrow = True
except ImportError:
have_arrow = False
if not have_arrow:
raise ImportError("PyArrow >= %s must be installed; however, "
"it was not found." % minimum_pyarrow_version)
if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
raise ImportError("PyArrow >= %s must be installed; however, "
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
if os.environ.get("ARROW_PRE_0_15_IPC_FORMAT", "0") == "1":
raise RuntimeError("Arrow legacy IPC format is not supported in PySpark, "
"please unset ARROW_PRE_0_15_IPC_FORMAT")
def require_test_compiled():
""" Raise Exception if test classes are not compiled
"""

View file

@ -29,7 +29,7 @@ from pyspark.util import _exception_message
pandas_requirement_message = None
try:
from pyspark.sql.utils import require_minimum_pandas_version
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
except ImportError as e:
# If Pandas version requirement is not satisfied, skip related tests.
@ -37,7 +37,7 @@ except ImportError as e:
pyarrow_requirement_message = None
try:
from pyspark.sql.utils import require_minimum_pyarrow_version
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
require_minimum_pyarrow_version()
except ImportError as e:
# If Arrow version requirement is not satisfied, skip related tests.

View file

@ -39,8 +39,10 @@ from pyspark.resourceinformation import ResourceInformation
from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
from pyspark.sql.types import to_arrow_type, StructType
BatchedSerializer
from pyspark.sql.pandas.serializers import ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle

View file

@ -179,6 +179,8 @@ try:
'pyspark.ml.linalg',
'pyspark.ml.param',
'pyspark.sql',
'pyspark.sql.avro',
'pyspark.sql.pandas',
'pyspark.streaming',
'pyspark.bin',
'pyspark.sbin',

View file

@ -105,7 +105,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
Seq(
pythonExec,
"-c",
"from pyspark.sql.utils import require_minimum_pandas_version;" +
"from pyspark.sql.pandas.utils import require_minimum_pandas_version;" +
"require_minimum_pandas_version()"),
None,
"PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
@ -117,7 +117,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
Seq(
pythonExec,
"-c",
"from pyspark.sql.utils import require_minimum_pyarrow_version;" +
"from pyspark.sql.pandas.utils import require_minimum_pyarrow_version;" +
"require_minimum_pyarrow_version()"),
None,
"PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!