[SPARK-30434][PYTHON][SQL] Move pandas related functionalities into 'pandas' sub-package
### What changes were proposed in this pull request?

This PR proposes to move the pandas related functionalities into a `pandas` sub-package. Namely:

```bash
pyspark/sql/pandas
├── __init__.py
├── conversion.py    # Conversion between pandas <> PySpark DataFrames
├── functions.py     # pandas_udf
├── group_ops.py     # Grouped UDF / Cogrouped UDF + groupby.apply, groupby.cogroup.apply
├── map_ops.py       # Map Iter UDF + mapInPandas
├── serializers.py   # pandas <> PyArrow serializers
├── types.py         # Type utils between pandas <> PyArrow
└── utils.py         # Version requirement checks
```

In order to locate `groupby.apply`, `groupby.cogroup.apply`, `mapInPandas`, `toPandas`, and `createDataFrame(pdf)` separately under the `pandas` sub-package, I had to use a mix-in approach, which the Scala side often uses via `trait`s; pandas itself also uses this approach to group related functionalities (see `IndexOpsMixin` as an example). You can think of it as Scala's self-typed trait. See the structure below:

```python
class PandasMapOpsMixin(object):
    def mapInPandas(self, ...):
        ...
        return ...

    # other pandas <> PySpark APIs
```

```python
class DataFrame(PandasMapOpsMixin):
    # other DataFrame APIs equivalent to the Scala side.
    ...
```

Yes, this is a big PR, but the changes are mostly just moving code around, except for one case, `createDataFrame`, where I had to split the method.

### Why are the changes needed?

The pandas functionalities are scattered here and there, and I myself get lost about where each one lives. Also, making a change that applies to all pandas related features is almost impossible now. In addition, after this change, `DataFrame` and `SparkSession` become more consistent with the Scala side, since pandas is specific to Python: this change separates the pandas-specific APIs away from `DataFrame` and `SparkSession`.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests should cover this. I also manually built the PySpark API documentation and checked it.

Closes #27109 from HyukjinKwon/pandas-refactoring.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
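To make the mix-in approach described above concrete, here is a minimal, self-contained sketch of the pattern. It is illustrative only, not the actual PySpark code: the method body, the `rows` attribute, and the toy `DataFrame` class are made up for the example. The point is that the mix-in assumes it is only ever mixed into `DataFrame` (so it can assert the type of `self`), much like a self-typed trait in Scala:

```python
class PandasMapOpsMixin(object):
    """Hosts the pandas-specific methods; only meaningful when mixed into DataFrame."""

    def mapInPandas(self, func):
        # Self-typed: the mix-in relies on attributes defined by DataFrame.
        assert isinstance(self, DataFrame), \
            "PandasMapOpsMixin is only usable as part of DataFrame"
        return DataFrame(func(row) for row in self.rows)


class DataFrame(PandasMapOpsMixin):
    """The single user-facing class; pandas-specific APIs come from the mix-in."""

    def __init__(self, rows):
        self.rows = list(rows)


df = DataFrame([1, 2, 3])
print(df.mapInPandas(lambda x: x * 2).rows)  # [2, 4, 6]
```

This keeps the pandas-specific surface in one place while `DataFrame` itself stays the only class users ever see.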
This commit is contained in: parent 18daa37cdb, commit ee8d661058
@@ -362,6 +362,13 @@ pyspark_sql = Module(
        "pyspark.sql.udf",
        "pyspark.sql.window",
        "pyspark.sql.avro.functions",
        "pyspark.sql.pandas.conversion",
        "pyspark.sql.pandas.map_ops",
        "pyspark.sql.pandas.functions",
        "pyspark.sql.pandas.group_ops",
        "pyspark.sql.pandas.types",
        "pyspark.sql.pandas.serializers",
        "pyspark.sql.pandas.utils",
        # unittests
        "pyspark.sql.tests.test_arrow",
        "pyspark.sql.tests.test_catalog",
@@ -24,7 +24,7 @@ Run with:
from __future__ import print_function

from pyspark.sql import SparkSession
from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version

require_minimum_pandas_version()
require_minimum_pyarrow_version()
@@ -7,6 +7,7 @@ Module Context
.. automodule:: pyspark.sql
    :members:
    :undoc-members:
    :inherited-members:
    :exclude-members: builder
.. We need `exclude-members` to prevent default description generations
   as a workaround for old Sphinx (< 1.6.6).
@ -185,248 +185,6 @@ class FramedSerializer(Serializer):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class ArrowCollectSerializer(Serializer):
|
||||
"""
|
||||
Deserialize a stream of batches followed by batch order information. Used in
|
||||
DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.serializer = ArrowStreamSerializer()
|
||||
|
||||
def dump_stream(self, iterator, stream):
|
||||
return self.serializer.dump_stream(iterator, stream)
|
||||
|
||||
def load_stream(self, stream):
|
||||
"""
|
||||
Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
|
||||
a list of indices that can be used to put the RecordBatches in the correct order.
|
||||
"""
|
||||
# load the batches
|
||||
for batch in self.serializer.load_stream(stream):
|
||||
yield batch
|
||||
|
||||
# load the batch order indices or propagate any error that occurred in the JVM
|
||||
num = read_int(stream)
|
||||
if num == -1:
|
||||
error_msg = UTF8Deserializer().loads(stream)
|
||||
raise RuntimeError("An error occurred while calling "
|
||||
"ArrowCollectSerializer.load_stream: {}".format(error_msg))
|
||||
batch_order = []
|
||||
for i in xrange(num):
|
||||
index = read_int(stream)
|
||||
batch_order.append(index)
|
||||
yield batch_order
|
||||
|
||||
def __repr__(self):
|
||||
return "ArrowCollectSerializer(%s)" % self.serializer
|
||||
|
||||
|
||||
class ArrowStreamSerializer(Serializer):
|
||||
"""
|
||||
Serializes Arrow record batches as a stream.
|
||||
"""
|
||||
|
||||
def dump_stream(self, iterator, stream):
|
||||
import pyarrow as pa
|
||||
writer = None
|
||||
try:
|
||||
for batch in iterator:
|
||||
if writer is None:
|
||||
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
|
||||
writer.write_batch(batch)
|
||||
finally:
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
|
||||
def load_stream(self, stream):
|
||||
import pyarrow as pa
|
||||
reader = pa.ipc.open_stream(stream)
|
||||
for batch in reader:
|
||||
yield batch
|
||||
|
||||
def __repr__(self):
|
||||
return "ArrowStreamSerializer"
|
||||
|
||||
|
||||
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
|
||||
"""
|
||||
Serializes Pandas.Series as Arrow data with Arrow streaming format.
|
||||
|
||||
:param timezone: A timezone to respect when handling timestamp values
|
||||
:param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
|
||||
:param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
|
||||
"""
|
||||
|
||||
def __init__(self, timezone, safecheck, assign_cols_by_name):
|
||||
super(ArrowStreamPandasSerializer, self).__init__()
|
||||
self._timezone = timezone
|
||||
self._safecheck = safecheck
|
||||
self._assign_cols_by_name = assign_cols_by_name
|
||||
|
||||
def arrow_to_pandas(self, arrow_column):
|
||||
from pyspark.sql.types import _check_series_localize_timestamps
|
||||
|
||||
# If the given column is a date type column, creates a series of datetime.date directly
|
||||
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
|
||||
# datetime64[ns] type handling.
|
||||
s = arrow_column.to_pandas(date_as_object=True)
|
||||
|
||||
s = _check_series_localize_timestamps(s, self._timezone)
|
||||
return s
|
||||
|
||||
def _create_batch(self, series):
|
||||
"""
|
||||
Create an Arrow record batch from the given pandas.Series or list of Series,
|
||||
with optional type.
|
||||
|
||||
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
|
||||
:return: Arrow RecordBatch
|
||||
"""
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pyspark.sql.types import _check_series_convert_timestamps_internal
|
||||
# Make input conform to [(series1, type1), (series2, type2), ...]
|
||||
if not isinstance(series, (list, tuple)) or \
|
||||
(len(series) == 2 and isinstance(series[1], pa.DataType)):
|
||||
series = [series]
|
||||
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
|
||||
|
||||
def create_array(s, t):
|
||||
mask = s.isnull()
|
||||
# Ensure timestamp series are in expected form for Spark internal representation
|
||||
if t is not None and pa.types.is_timestamp(t):
|
||||
s = _check_series_convert_timestamps_internal(s, self._timezone)
|
||||
try:
|
||||
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
|
||||
except pa.ArrowException as e:
|
||||
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
|
||||
"Array (%s). It can be caused by overflows or other unsafe " + \
|
||||
"conversions warned by Arrow. Arrow safe type check can be " + \
|
||||
"disabled by using SQL config " + \
|
||||
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
|
||||
raise RuntimeError(error_msg % (s.dtype, t), e)
|
||||
return array
|
||||
|
||||
arrs = []
|
||||
for s, t in series:
|
||||
if t is not None and pa.types.is_struct(t):
|
||||
if not isinstance(s, pd.DataFrame):
|
||||
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
|
||||
"but got: %s" % str(type(s)))
|
||||
|
||||
# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
|
||||
if len(s) == 0 and len(s.columns) == 0:
|
||||
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
|
||||
# Assign result columns by schema name if user labeled with strings
|
||||
elif self._assign_cols_by_name and any(isinstance(name, basestring)
|
||||
for name in s.columns):
|
||||
arrs_names = [(create_array(s[field.name], field.type), field.name)
|
||||
for field in t]
|
||||
# Assign result columns by position
|
||||
else:
|
||||
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
|
||||
for i, field in enumerate(t)]
|
||||
|
||||
struct_arrs, struct_names = zip(*arrs_names)
|
||||
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
|
||||
else:
|
||||
arrs.append(create_array(s, t))
|
||||
|
||||
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
|
||||
|
||||
def dump_stream(self, iterator, stream):
|
||||
"""
|
||||
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
|
||||
a list of series accompanied by an optional pyarrow type to coerce the data to.
|
||||
"""
|
||||
batches = (self._create_batch(series) for series in iterator)
|
||||
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
|
||||
|
||||
def load_stream(self, stream):
|
||||
"""
|
||||
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
|
||||
"""
|
||||
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
|
||||
import pyarrow as pa
|
||||
for batch in batches:
|
||||
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
|
||||
|
||||
def __repr__(self):
|
||||
return "ArrowStreamPandasSerializer"
|
||||
|
||||
|
||||
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
|
||||
"""
|
||||
Serializer used by Python worker to evaluate Pandas UDFs
|
||||
"""
|
||||
|
||||
def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
|
||||
super(ArrowStreamPandasUDFSerializer, self) \
|
||||
.__init__(timezone, safecheck, assign_cols_by_name)
|
||||
self._df_for_struct = df_for_struct
|
||||
|
||||
def arrow_to_pandas(self, arrow_column):
|
||||
import pyarrow.types as types
|
||||
|
||||
if self._df_for_struct and types.is_struct(arrow_column.type):
|
||||
import pandas as pd
|
||||
series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
|
||||
.rename(field.name)
|
||||
for column, field in zip(arrow_column.flatten(), arrow_column.type)]
|
||||
s = pd.concat(series, axis=1)
|
||||
else:
|
||||
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
|
||||
return s
|
||||
|
||||
def dump_stream(self, iterator, stream):
|
||||
"""
|
||||
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
|
||||
This should be sent after creating the first record batch so in case of an error, it can
|
||||
be sent back to the JVM before the Arrow stream starts.
|
||||
"""
|
||||
|
||||
def init_stream_yield_batches():
|
||||
should_write_start_length = True
|
||||
for series in iterator:
|
||||
batch = self._create_batch(series)
|
||||
if should_write_start_length:
|
||||
write_int(SpecialLengths.START_ARROW_STREAM, stream)
|
||||
should_write_start_length = False
|
||||
yield batch
|
||||
|
||||
return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
|
||||
|
||||
def __repr__(self):
|
||||
return "ArrowStreamPandasUDFSerializer"
|
||||
|
||||
|
||||
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
|
||||
|
||||
def load_stream(self, stream):
|
||||
"""
|
||||
Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
|
||||
lists of pandas.Series.
|
||||
"""
|
||||
import pyarrow as pa
|
||||
dataframes_in_group = None
|
||||
|
||||
while dataframes_in_group is None or dataframes_in_group > 0:
|
||||
dataframes_in_group = read_int(stream)
|
||||
|
||||
if dataframes_in_group == 2:
|
||||
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
|
||||
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
|
||||
yield (
|
||||
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
|
||||
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
|
||||
)
|
||||
|
||||
elif dataframes_in_group != 0:
|
||||
raise ValueError(
|
||||
'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))
|
||||
|
||||
|
||||
class BatchedSerializer(Serializer):
|
||||
|
||||
"""
|
||||
|
|
|
@@ -51,12 +51,12 @@ from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStat
from pyspark.sql.group import GroupedData
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec
from pyspark.sql.cogroup import CoGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps


__all__ = [
    'SparkSession', 'SQLContext', 'UDFRegistration',
    'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
    'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
    'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
    'DataFrameReader', 'DataFrameWriter', 'PandasCogroupedOps'
]
@@ -31,8 +31,8 @@ import warnings

from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, \
    ignore_unicode_prefix, PythonEvalType
from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
    ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, \
    UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync

@@ -40,14 +40,14 @@ from pyspark.sql.types import _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter
from pyspark.sql.streaming import DataStreamWriter
from pyspark.sql.types import IntegralType
from pyspark.sql.types import *
from pyspark.util import _exception_message
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin

__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]


class DataFrame(object):
class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
    """A distributed collection of data grouped into named columns.

    A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
@ -2135,193 +2135,6 @@ class DataFrame(object):
|
|||
"should have been DataFrame." % type(result)
|
||||
return result
|
||||
|
||||
@since(1.3)
|
||||
def toPandas(self):
|
||||
"""
|
||||
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
|
||||
|
||||
This is only available if Pandas is installed and available.
|
||||
|
||||
.. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is
|
||||
expected to be small, as all the data is loaded into the driver's memory.
|
||||
|
||||
.. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.
|
||||
|
||||
>>> df.toPandas() # doctest: +SKIP
|
||||
age name
|
||||
0 2 Alice
|
||||
1 5 Bob
|
||||
"""
|
||||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
if self.sql_ctx._conf.pandasRespectSessionTimeZone():
|
||||
timezone = self.sql_ctx._conf.sessionLocalTimeZone()
|
||||
else:
|
||||
timezone = None
|
||||
|
||||
if self.sql_ctx._conf.arrowPySparkEnabled():
|
||||
use_arrow = True
|
||||
try:
|
||||
from pyspark.sql.types import to_arrow_schema
|
||||
from pyspark.sql.utils import require_minimum_pyarrow_version
|
||||
|
||||
require_minimum_pyarrow_version()
|
||||
to_arrow_schema(self.schema)
|
||||
except Exception as e:
|
||||
|
||||
if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
|
||||
msg = (
|
||||
"toPandas attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
|
||||
"failed by the reason below:\n %s\n"
|
||||
"Attempting non-optimization as "
|
||||
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
|
||||
"true." % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
use_arrow = False
|
||||
else:
|
||||
msg = (
|
||||
"toPandas attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
|
||||
"reached the error below and will not continue because automatic fallback "
|
||||
"with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
|
||||
"false.\n %s" % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
raise
|
||||
|
||||
# Try to use Arrow optimization when the schema is supported and the required version
|
||||
# of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
|
||||
if use_arrow:
|
||||
try:
|
||||
from pyspark.sql.types import _check_dataframe_localize_timestamps
|
||||
import pyarrow
|
||||
batches = self._collectAsArrow()
|
||||
if len(batches) > 0:
|
||||
table = pyarrow.Table.from_batches(batches)
|
||||
# Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
|
||||
# values, but we should use datetime.date to match the behavior with when
|
||||
# Arrow optimization is disabled.
|
||||
pdf = table.to_pandas(date_as_object=True)
|
||||
return _check_dataframe_localize_timestamps(pdf, timezone)
|
||||
else:
|
||||
return pd.DataFrame.from_records([], columns=self.columns)
|
||||
except Exception as e:
|
||||
# We might have to allow fallback here as well but multiple Spark jobs can
|
||||
# be executed. So, simply fail in this case for now.
|
||||
msg = (
|
||||
"toPandas attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
|
||||
"reached the error below and can not continue. Note that "
|
||||
"'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
|
||||
"effect on failures in the middle of "
|
||||
"computation.\n %s" % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
raise
|
||||
|
||||
# Below is toPandas without Arrow optimization.
|
||||
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
|
||||
|
||||
dtype = {}
|
||||
for field in self.schema:
|
||||
pandas_type = _to_corrected_pandas_type(field.dataType)
|
||||
# SPARK-21766: if an integer field is nullable and has null values, it can be
|
||||
# inferred by pandas as float column. Once we convert the column with NaN back
|
||||
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
|
||||
# float type, not the corrected type from the schema in this case.
|
||||
if pandas_type is not None and \
|
||||
not(isinstance(field.dataType, IntegralType) and field.nullable and
|
||||
pdf[field.name].isnull().any()):
|
||||
dtype[field.name] = pandas_type
|
||||
# Ensure we fall back to nullable numpy types, even when whole column is null:
|
||||
if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
|
||||
dtype[field.name] = np.float64
|
||||
if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
|
||||
dtype[field.name] = np.object
|
||||
|
||||
for f, t in dtype.items():
|
||||
pdf[f] = pdf[f].astype(t, copy=False)
|
||||
|
||||
if timezone is None:
|
||||
return pdf
|
||||
else:
|
||||
from pyspark.sql.types import _check_series_convert_timestamps_local_tz
|
||||
for field in self.schema:
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if isinstance(field.dataType, TimestampType):
|
||||
pdf[field.name] = \
|
||||
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
|
||||
return pdf
|
||||
|
||||
def mapInPandas(self, udf):
|
||||
"""
|
||||
Maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
|
||||
function and returns the result as a :class:`DataFrame`.
|
||||
|
||||
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return
|
||||
another iterator of `pandas.DataFrame`\\s. All columns are passed
|
||||
together as an iterator of `pandas.DataFrame`\\s to the user-defined function and the
|
||||
returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
|
||||
Each `pandas.DataFrame` size can be controlled by
|
||||
`spark.sql.execution.arrow.maxRecordsPerBatch`.
|
||||
Its schema must match the returnType of the Pandas user-defined function.
|
||||
|
||||
:param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
|
||||
... ("id", "age")) # doctest: +SKIP
|
||||
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
|
||||
... def filter_func(batch_iter):
|
||||
... for pdf in batch_iter:
|
||||
... yield pdf[pdf.id == 1]
|
||||
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
|
||||
+---+---+
|
||||
| id|age|
|
||||
+---+---+
|
||||
| 1| 21|
|
||||
+---+---+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
|
||||
|
||||
"""
|
||||
# Columns are special because hasattr always return True
|
||||
if isinstance(udf, Column) or not hasattr(udf, 'func') \
|
||||
or udf.evalType != PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
|
||||
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
|
||||
"MAP_ITER.")
|
||||
|
||||
udf_column = udf(*[self[col] for col in self.columns])
|
||||
jdf = self._jdf.mapInPandas(udf_column._jc.expr())
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
def _collectAsArrow(self):
|
||||
"""
|
||||
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
|
||||
and available on driver and worker Python environments.
|
||||
|
||||
.. note:: Experimental.
|
||||
"""
|
||||
with SCCallSiteSync(self._sc) as css:
|
||||
port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
|
||||
|
||||
# Collect list of un-ordered batches where last element is a list of correct order indices
|
||||
try:
|
||||
results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
|
||||
finally:
|
||||
# Join serving thread and raise any exceptions from collectAsArrowToPython
|
||||
jsocket_auth_server.getResult()
|
||||
|
||||
# Separate RecordBatches from batch order indices in results
|
||||
batches = results[:-1]
|
||||
batch_order = results[-1]
|
||||
|
||||
# Re-order the batch list using the correct order
|
||||
return [batches[i] for i in batch_order]
|
||||
|
||||
##########################################################################################
|
||||
# Pandas compatibility
|
||||
##########################################################################################
|
||||
|
@ -2349,33 +2162,6 @@ def _to_scala_map(sc, jm):
|
|||
return sc._jvm.PythonUtils.toScalaMap(jm)
|
||||
|
||||
|
||||
def _to_corrected_pandas_type(dt):
|
||||
"""
|
||||
When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type may be
|
||||
wrong. This method gets the corrected data type for Pandas if that type may be inferred
|
||||
uncorrectly.
|
||||
"""
|
||||
import numpy as np
|
||||
if type(dt) == ByteType:
|
||||
return np.int8
|
||||
elif type(dt) == ShortType:
|
||||
return np.int16
|
||||
elif type(dt) == IntegerType:
|
||||
return np.int32
|
||||
elif type(dt) == LongType:
|
||||
return np.int64
|
||||
elif type(dt) == FloatType:
|
||||
return np.float32
|
||||
elif type(dt) == DoubleType:
|
||||
return np.float64
|
||||
elif type(dt) == BooleanType:
|
||||
return np.bool
|
||||
elif type(dt) == TimestampType:
|
||||
return np.datetime64
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class DataFrameNaFunctions(object):
|
||||
"""Functionality for working with missing data in :class:`DataFrame`.
|
||||
|
||||
|
|
|
@@ -36,6 +36,8 @@ from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, DataType
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_udf
# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
from pyspark.sql.utils import to_str

# Note to developers: all of PySpark functions here take string as column names whenever possible.
@@ -2814,22 +2816,6 @@ def from_csv(col, schema, options={}):

# ---------------------------- User Defined Function ----------------------------------

class PandasUDFType(object):
    """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`.
    """
    SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF

    SCALAR_ITER = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF

    GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

    COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF

    GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF

    MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF


@since(1.3)
def udf(f=None, returnType=StringType()):
    """Creates a user defined function (UDF).
@ -2917,483 +2903,6 @@ def udf(f=None, returnType=StringType()):
|
|||
evalType=PythonEvalType.SQL_BATCHED_UDF)
|
||||
|
||||
|
||||
@since(2.3)
|
||||
def pandas_udf(f=None, returnType=None, functionType=None):
|
||||
"""
|
||||
Creates a vectorized user defined function (UDF).
|
||||
|
||||
:param f: user-defined function. A python function if used as a standalone function
|
||||
:param returnType: the return type of the user-defined function. The value can be either a
|
||||
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
|
||||
:param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
|
||||
Default: SCALAR.
|
||||
|
||||
The function type of the UDF can be one of the following:
|
||||
|
||||
1. SCALAR
|
||||
|
||||
A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
|
||||
The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
|
||||
If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
|
||||
|
||||
:class:`MapType`, nested :class:`StructType` are currently not supported as output types.
|
||||
|
||||
Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and
|
||||
:meth:`pyspark.sql.DataFrame.select`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> from pyspark.sql.types import IntegerType, StringType
|
||||
>>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP
|
||||
>>> @pandas_udf(StringType()) # doctest: +SKIP
|
||||
... def to_upper(s):
|
||||
... return s.str.upper()
|
||||
...
|
||||
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
|
||||
... def add_one(x):
|
||||
... return x + 1
|
||||
...
|
||||
>>> df = spark.createDataFrame([(1, "John Doe", 21)],
|
||||
... ("id", "name", "age")) # doctest: +SKIP
|
||||
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
|
||||
... .show() # doctest: +SKIP
|
||||
+----------+--------------+------------+
|
||||
|slen(name)|to_upper(name)|add_one(age)|
|
||||
+----------+--------------+------------+
|
||||
| 8| JOHN DOE| 22|
|
||||
+----------+--------------+------------+
|
||||
>>> @pandas_udf("first string, last string") # doctest: +SKIP
|
||||
... def split_expand(n):
|
||||
... return n.str.split(expand=True)
|
||||
>>> df.select(split_expand("name")).show() # doctest: +SKIP
|
||||
+------------------+
|
||||
|split_expand(name)|
|
||||
+------------------+
|
||||
| [John, Doe]|
|
||||
+------------------+
|
||||
|
||||
.. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
|
||||
column, but is the length of an internal batch used for each call to the function.
|
||||
Therefore, this can be used, for example, to ensure the length of each returned
|
||||
`pandas.Series`, and can not be used as the column length.
|
||||
|
||||
2. SCALAR_ITER
|
||||
|
||||
A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the
|
||||
wrapped Python function takes an iterator of batches as input instead of a single batch and,
|
||||
instead of returning a single output batch, it yields output batches or explicitly returns an
|
||||
generator or an iterator of output batches.
|
||||
It is useful when the UDF execution requires initializing some state, e.g., loading a machine
|
||||
learning model file to apply inference to every input batch.
|
||||
|
||||
.. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all
|
||||
batches from one partition, although it is currently implemented this way.
|
||||
Your code shall not rely on this behavior because it might change in the future for
|
||||
further optimization, e.g., one invocation processes multiple partitions.
|
||||
|
||||
Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
|
||||
:meth:`pyspark.sql.DataFrame.select`.
|
||||
|
||||
>>> import pandas as pd # doctest: +SKIP
|
||||
>>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType
|
||||
>>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP
|
||||
>>> df = spark.createDataFrame(pdf) # doctest: +SKIP
|
||||
|
||||
When the UDF is called with a single column that is not `StructType`, the input to the
|
||||
underlying function is an iterator of `pd.Series`.
|
||||
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def plus_one(batch_iter):
|
||||
... for x in batch_iter:
|
||||
... yield x + 1
|
||||
...
|
||||
>>> df.select(plus_one(col("x"))).show() # doctest: +SKIP
|
||||
+-----------+
|
||||
|plus_one(x)|
|
||||
+-----------+
|
||||
| 2|
|
||||
| 3|
|
||||
| 4|
|
||||
+-----------+
|
||||
|
||||
When the UDF is called with more than one columns, the input to the underlying function is an
|
||||
iterator of `pd.Series` tuple.
|
||||
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def multiply_two_cols(batch_iter):
|
||||
... for a, b in batch_iter:
|
||||
... yield a * b
|
||||
...
|
||||
>>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP
|
||||
+-----------------------+
|
||||
|multiply_two_cols(x, x)|
|
||||
+-----------------------+
|
||||
| 1|
|
||||
| 4|
|
||||
| 9|
|
||||
+-----------------------+
|
||||
|
||||
When the UDF is called with a single column that is `StructType`, the input to the underlying
|
||||
function is an iterator of `pd.DataFrame`.
|
||||
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def multiply_two_nested_cols(pdf_iter):
|
||||
... for pdf in pdf_iter:
|
||||
... yield pdf["a"] * pdf["b"]
|
||||
...
|
||||
>>> df.select(
|
||||
... multiply_two_nested_cols(
|
||||
... struct(col("x").alias("a"), col("x").alias("b"))
|
||||
... ).alias("y")
|
||||
... ).show() # doctest: +SKIP
|
||||
+---+
|
||||
| y|
|
||||
+---+
|
||||
| 1|
|
||||
| 4|
|
||||
| 9|
|
||||
+---+
|
||||
|
||||
In the UDF, you can initialize some states before processing batches, wrap your code with
|
||||
`try ... finally ...` or use context managers to ensure the release of resources at the end
|
||||
or in case of early termination.
|
||||
|
||||
>>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def plus_y(batch_iter):
|
||||
... y = y_bc.value # initialize some state
|
||||
... try:
|
||||
... for x in batch_iter:
|
||||
... yield x + y
|
||||
... finally:
|
||||
... pass # release resources here, if any
|
||||
...
|
||||
>>> df.select(plus_y(col("x"))).show() # doctest: +SKIP
|
||||
+---------+
|
||||
|plus_y(x)|
|
||||
+---------+
|
||||
| 2|
|
||||
| 3|
|
||||
| 4|
|
||||
+---------+
|
||||
|
||||
3. GROUPED_MAP
|
||||
|
||||
A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
|
||||
The returnType should be a :class:`StructType` describing the schema of the returned
|
||||
`pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match
|
||||
the field names in the defined returnType schema if specified as strings, or match the
|
||||
field data types by position if not strings, e.g. integer indices.
|
||||
The length of the returned `pandas.DataFrame` can be arbitrary.
|
||||
|
||||
Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v")) # doctest: +SKIP
|
||||
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def normalize(pdf):
|
||||
... v = pdf.v
|
||||
... return pdf.assign(v=(v - v.mean()) / v.std())
|
||||
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
|
||||
+---+-------------------+
|
||||
| id| v|
|
||||
+---+-------------------+
|
||||
| 1|-0.7071067811865475|
|
||||
| 1| 0.7071067811865475|
|
||||
| 2|-0.8320502943378437|
|
||||
| 2|-0.2773500981126146|
|
||||
| 2| 1.1094003924504583|
|
||||
+---+-------------------+
|
||||
|
||||
Alternatively, the user can define a function that takes two arguments.
|
||||
In this case, the grouping key(s) will be passed as the first argument and the data will
|
||||
be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
|
||||
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
|
||||
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
|
||||
This is useful when the user does not want to hardcode grouping key(s) in the function.
|
||||
|
||||
>>> import pandas as pd # doctest: +SKIP
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v")) # doctest: +SKIP
|
||||
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def mean_udf(key, pdf):
|
||||
... # key is a tuple of one numpy.int64, which is the value
|
||||
... # of 'id' for the current group
|
||||
... return pd.DataFrame([key + (pdf.v.mean(),)])
|
||||
>>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
|
||||
+---+---+
|
||||
| id| v|
|
||||
+---+---+
|
||||
| 1|1.5|
|
||||
| 2|6.0|
|
||||
+---+---+
|
||||
>>> @pandas_udf(
|
||||
... "id long, `ceil(v / 2)` long, v double",
|
||||
... PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
>>> def sum_udf(key, pdf):
|
||||
... # key is a tuple of two numpy.int64s, which is the values
|
||||
... # of 'id' and 'ceil(df.v / 2)' for the current group
|
||||
... return pd.DataFrame([key + (pdf.v.sum(),)])
|
||||
>>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP
|
||||
+---+-----------+----+
|
||||
| id|ceil(v / 2)| v|
|
||||
+---+-----------+----+
|
||||
| 2| 5|10.0|
|
||||
| 1| 1| 3.0|
|
||||
| 2| 3| 5.0|
|
||||
| 2| 2| 3.0|
|
||||
+---+-----------+----+
|
||||
|
||||
.. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
|
||||
recommended to explicitly index the columns by name to ensure the positions are correct,
|
||||
or alternatively use an `OrderedDict`.
|
||||
For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
|
||||
`pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
|
||||
|
||||
4. GROUPED_AGG
|
||||
|
||||
A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
|
||||
The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
|
||||
The returned scalar can be either a python primitive type, e.g., `int` or `float`
|
||||
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
|
||||
|
||||
:class:`MapType` and :class:`StructType` are currently not supported as output types.
|
||||
|
||||
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
|
||||
:class:`pyspark.sql.Window`
|
||||
|
||||
This example shows using grouped aggregated UDFs with groupby:
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v"))
|
||||
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
|
||||
... def mean_udf(v):
|
||||
... return v.mean()
|
||||
>>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP
|
||||
+---+-----------+
|
||||
| id|mean_udf(v)|
|
||||
+---+-----------+
|
||||
| 1| 1.5|
|
||||
| 2| 6.0|
|
||||
+---+-----------+
|
||||
|
||||
This example shows using grouped aggregated UDFs as window functions.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> from pyspark.sql import Window
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v"))
|
||||
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
|
||||
... def mean_udf(v):
|
||||
... return v.mean()
|
||||
>>> w = (Window.partitionBy('id')
|
||||
... .orderBy('v')
|
||||
... .rowsBetween(-1, 0))
|
||||
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
|
||||
+---+----+------+
|
||||
| id| v|mean_v|
|
||||
+---+----+------+
|
||||
| 1| 1.0| 1.0|
|
||||
| 1| 2.0| 1.5|
|
||||
| 2| 3.0| 3.0|
|
||||
| 2| 5.0| 4.0|
|
||||
| 2|10.0| 7.5|
|
||||
+---+----+------+
|
||||
|
||||
.. note:: For performance reasons, the input series to window functions are not copied.
|
||||
Therefore, mutating the input series is not allowed and will cause incorrect results.
|
||||
For the same reason, users should also not rely on the index of the input series.
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
|
||||
|
||||
5. MAP_ITER
|
||||
|
||||
A map iterator Pandas UDFs are used to transform data with an iterator of batches.
|
||||
It can be used with :meth:`pyspark.sql.DataFrame.mapInPandas`.
|
||||
|
||||
It can return the output of arbitrary length in contrast to the scalar Pandas UDF.
|
||||
It maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
|
||||
function and returns the result as a :class:`DataFrame`.
|
||||
|
||||
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return another
|
||||
iterator of `pandas.DataFrame`\\s. All columns are passed together as an
|
||||
iterator of `pandas.DataFrame`\\s to the user-defined function and the returned iterator of
|
||||
`pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
|
||||
|
||||
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
|
||||
... ("id", "age")) # doctest: +SKIP
|
||||
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
|
||||
... def filter_func(batch_iter):
|
||||
... for pdf in batch_iter:
|
||||
... yield pdf[pdf.id == 1]
|
||||
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
|
||||
+---+---+
|
||||
| id|age|
|
||||
+---+---+
|
||||
| 1| 21|
|
||||
+---+---+
|
||||
|
||||
6. COGROUPED_MAP
|
||||
|
||||
A cogrouped map UDF defines transformation: (`pandas.DataFrame`, `pandas.DataFrame`) ->
|
||||
`pandas.DataFrame`. The `returnType` should be a :class:`StructType` describing the schema
|
||||
of the returned `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame`
|
||||
must either match the field names in the defined `returnType` schema if specified as strings,
|
||||
or match the field data types by position if not strings, e.g. integer indices. The length
|
||||
of the returned `pandas.DataFrame` can be arbitrary.
|
||||
|
||||
CoGrouped map UDFs are used with :meth:`pyspark.sql.CoGroupedData.apply`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df1 = spark.createDataFrame(
|
||||
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
|
||||
... ("time", "id", "v1"))
|
||||
>>> df2 = spark.createDataFrame(
|
||||
... [(20000101, 1, "x"), (20000101, 2, "y")],
|
||||
... ("time", "id", "v2"))
|
||||
>>> @pandas_udf("time int, id int, v1 double, v2 string",
|
||||
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
|
||||
... def asof_join(l, r):
|
||||
... return pd.merge_asof(l, r, on="time", by="id")
|
||||
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
|
||||
+---------+---+---+---+
|
||||
| time| id| v1| v2|
|
||||
+---------+---+---+---+
|
||||
| 20000101| 1|1.0| x|
|
||||
| 20000102| 1|3.0| x|
|
||||
| 20000101| 2|2.0| y|
|
||||
| 20000102| 2|4.0| y|
|
||||
+---------+---+---+---+
|
||||
|
||||
Alternatively, the user can define a function that takes three arguments. In this case,
|
||||
the grouping key(s) will be passed as the first argument and the data will be passed as the
|
||||
second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
|
||||
types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
|
||||
`pandas.DataFrame` containing all columns from the original Spark DataFrames.
|
||||
>>> @pandas_udf("time int, id int, v1 double, v2 string",
|
||||
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
|
||||
... def asof_join(k, l, r):
|
||||
... if k == (1,):
|
||||
... return pd.merge_asof(l, r, on="time", by="id")
|
||||
... else:
|
||||
... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
|
||||
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
|
||||
+---------+---+---+---+
|
||||
| time| id| v1| v2|
|
||||
+---------+---+---+---+
|
||||
| 20000101| 1|1.0| x|
|
||||
| 20000102| 1|3.0| x|
|
||||
+---------+---+---+---+
|
||||
|
||||
.. note:: The user-defined functions are considered deterministic by default. Due to
|
||||
optimization, duplicate invocations may be eliminated or the function may even be invoked
|
||||
more times than it is present in the query. If your function is not deterministic, call
|
||||
`asNondeterministic` on the user defined function. E.g.:
|
||||
|
||||
>>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP
|
||||
... def random(v):
|
||||
... import numpy as np
|
||||
... import pandas as pd
|
||||
... return pd.Series(np.random.randn(len(v))
|
||||
>>> random = random.asNondeterministic() # doctest: +SKIP
|
||||
|
||||
.. note:: The user-defined functions do not support conditional expressions or short circuiting
|
||||
in boolean expressions and it ends up with being executed all internally. If the functions
|
||||
can fail on special rows, the workaround is to incorporate the condition into the functions.
|
||||
|
||||
.. note:: The user-defined functions do not take keyword arguments on the calling side.
|
||||
|
||||
.. note:: The data type of returned `pandas.Series` from the user-defined functions should be
|
||||
matched with defined returnType (see :meth:`types.to_arrow_type` and
|
||||
:meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do
|
||||
conversion on returned data. The conversion is not guaranteed to be correct and results
|
||||
should be checked for accuracy by users.
|
||||
"""
|
||||
|
||||
# The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that
|
||||
# are not yet visible to the user. Some of behaviors are buggy and might be changed in the near
|
||||
# future. The table might have to be eventually documented externally.
|
||||
# Please see SPARK-28132's PR to see the codes in order to generate the table below.
|
||||
#
|
||||
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
|
||||
# |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa
|
||||
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
|
||||
# | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| 0| X| # noqa
|
||||
# | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
|
||||
# | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
|
||||
# | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| X| X| # noqa
|
||||
# | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| X| X| # noqa
|
||||
# | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| X| X| # noqa
|
||||
# | string| None| ''| ''| ''| '\x01'| '\x01'| ''| ''| '\x01'| '\x01'| ''| ''| ''| X| X| 'a'| X| X| ''| X| ''| X| X| # noqa
|
||||
# | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa
|
||||
# | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa
|
||||
# | map<string,int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')| X| bytearray(b'')| # noqa
|
||||
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
|
||||
#
|
||||
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
|
||||
# used in `returnType`.
|
||||
# Note: The values inside of the table are generated by `repr`.
|
||||
# Note: Python 3.7.3, Pandas 0.24.2 and PyArrow 0.13.0 are used.
|
||||
# Note: Timezone is KST.
|
||||
# Note: 'X' means it throws an exception during the conversion.
|
||||
|
||||
# decorator @pandas_udf(returnType, functionType)
|
||||
is_decorator = f is None or isinstance(f, (str, DataType))
|
||||
|
||||
if is_decorator:
|
||||
# If DataType has been passed as a positional argument
|
||||
# for decorator use it as a returnType
|
||||
return_type = f or returnType
|
||||
|
||||
if functionType is not None:
|
||||
# @pandas_udf(dataType, functionType=functionType)
|
||||
# @pandas_udf(returnType=dataType, functionType=functionType)
|
||||
eval_type = functionType
|
||||
elif returnType is not None and isinstance(returnType, int):
|
||||
# @pandas_udf(dataType, functionType)
|
||||
eval_type = returnType
|
||||
else:
|
||||
# @pandas_udf(dataType) or @pandas_udf(returnType=dataType)
|
||||
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
|
||||
else:
|
||||
return_type = returnType
|
||||
|
||||
if functionType is not None:
|
||||
eval_type = functionType
|
||||
else:
|
||||
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
|
||||
|
||||
if return_type is None:
|
||||
raise ValueError("Invalid returnType: returnType can not be None")
|
||||
|
||||
if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
||||
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
|
||||
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
|
||||
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
|
||||
raise ValueError("Invalid functionType: "
|
||||
"functionType must be one the values from PandasUDFType")
|
||||
|
||||
if is_decorator:
|
||||
return functools.partial(_create_udf, returnType=return_type, evalType=eval_type)
|
||||
else:
|
||||
return _create_udf(f=f, returnType=return_type, evalType=eval_type)
|
||||
|
||||
|
||||
blacklist = ['map', 'since', 'ignore_unicode_prefix']
|
||||
__all__ = [k for k, v in globals().items()
|
||||
if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist]
|
||||
|
|
|
@@ -18,11 +18,11 @@
import sys

from pyspark import since
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
from pyspark.sql.types import *
from pyspark.sql.cogroup import CoGroupedData

__all__ = ["GroupedData"]


@@ -47,7 +47,7 @@ def df_varargs_api(f):
    return _api


class GroupedData(object):
class GroupedData(PandasGroupedOpsMixin):
    """
    A set of methods for aggregations on a :class:`DataFrame`,
    created by :func:`DataFrame.groupBy`.
@ -219,68 +219,6 @@ class GroupedData(object):
|
|||
jgd = self._jgd.pivot(pivot_col, values)
|
||||
return GroupedData(jgd, self._df)
|
||||
|
||||
@since(3.0)
|
||||
def cogroup(self, other):
|
||||
"""
|
||||
Cogroups this group with another group so that we can run cogrouped operations.
|
||||
|
||||
See :class:`CoGroupedData` for the operations that can be run.
|
||||
"""
|
||||
return CoGroupedData(self, other)
|
||||
|
||||
@since(2.3)
|
||||
def apply(self, udf):
|
||||
"""
|
||||
Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
|
||||
as a `DataFrame`.
|
||||
|
||||
The user-defined function should take a `pandas.DataFrame` and return another
|
||||
`pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
|
||||
to the user-function and the returned `pandas.DataFrame` are combined as a
|
||||
:class:`DataFrame`.
|
||||
|
||||
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
|
||||
returnType of the pandas udf.
|
||||
|
||||
.. note:: This function requires a full shuffle. All the data of a group will be loaded
|
||||
into memory, so the user should be aware of the potential OOM risk if data is skewed
|
||||
and certain groups are too large to fit in memory.
|
||||
|
||||
:param udf: a grouped map user-defined function returned by
|
||||
:func:`pyspark.sql.functions.pandas_udf`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v"))
|
||||
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def normalize(pdf):
|
||||
... v = pdf.v
|
||||
... return pdf.assign(v=(v - v.mean()) / v.std())
|
||||
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
|
||||
+---+-------------------+
|
||||
| id| v|
|
||||
+---+-------------------+
|
||||
| 1|-0.7071067811865475|
|
||||
| 1| 0.7071067811865475|
|
||||
| 2|-0.8320502943378437|
|
||||
| 2|-0.2773500981126146|
|
||||
| 2| 1.1094003924504583|
|
||||
+---+-------------------+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
|
||||
|
||||
"""
|
||||
# Columns are special because hasattr always return True
|
||||
if isinstance(udf, Column) or not hasattr(udf, 'func') \
|
||||
or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
|
||||
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
|
||||
"GROUPED_MAP.")
|
||||
df = self._df
|
||||
udf_column = udf(*[df[col] for col in df.columns])
|
||||
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
python/pyspark/sql/pandas/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This package includes the internal APIs for PySpark about interoperability
between pandas, PySpark and PyArrow. This package should not be directly
imported and used.
"""
python/pyspark/sql/pandas/conversion.py (new file, 431 lines)
@@ -0,0 +1,431 @@
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import sys
|
||||
import warnings
|
||||
if sys.version >= '3':
|
||||
basestring = unicode = str
|
||||
xrange = range
|
||||
else:
|
||||
from itertools import izip as zip
|
||||
|
||||
from pyspark import since
|
||||
from pyspark.rdd import _load_from_socket
|
||||
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
|
||||
from pyspark.sql.types import IntegralType
|
||||
from pyspark.sql.types import *
|
||||
from pyspark.traceback_utils import SCCallSiteSync
|
||||
from pyspark.util import _exception_message
|
||||
|
||||
|
||||
class PandasConversionMixin(object):
|
||||
"""
|
||||
Mix-in for the conversion from Spark to pandas. Currently, only :class:`DataFrame`
|
||||
can use this class.
|
||||
"""
|
||||
|
||||
@since(1.3)
|
||||
def toPandas(self):
|
||||
"""
|
||||
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
|
||||
|
||||
This is only available if Pandas is installed and available.
|
||||
|
||||
.. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is
|
||||
expected to be small, as all the data is loaded into the driver's memory.
|
||||
|
||||
.. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.
|
||||
|
||||
>>> df.toPandas() # doctest: +SKIP
|
||||
age name
|
||||
0 2 Alice
|
||||
1 5 Bob
|
||||
"""
|
||||
from pyspark.sql.dataframe import DataFrame
|
||||
|
||||
assert isinstance(self, DataFrame)
|
||||
|
||||
from pyspark.sql.pandas.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
if self.sql_ctx._conf.pandasRespectSessionTimeZone():
|
||||
timezone = self.sql_ctx._conf.sessionLocalTimeZone()
|
||||
else:
|
||||
timezone = None
|
||||
|
||||
if self.sql_ctx._conf.arrowPySparkEnabled():
|
||||
use_arrow = True
|
||||
try:
|
||||
from pyspark.sql.pandas.types import to_arrow_schema
|
||||
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
|
||||
|
||||
require_minimum_pyarrow_version()
|
||||
to_arrow_schema(self.schema)
|
||||
except Exception as e:
|
||||
|
||||
if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
|
||||
msg = (
|
||||
"toPandas attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
|
||||
"failed by the reason below:\n %s\n"
|
||||
"Attempting non-optimization as "
|
||||
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
|
||||
"true." % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
use_arrow = False
|
||||
else:
|
||||
msg = (
|
||||
"toPandas attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
|
||||
"reached the error below and will not continue because automatic fallback "
|
||||
"with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
|
||||
"false.\n %s" % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
raise
|
||||
|
||||
# Try to use Arrow optimization when the schema is supported and the required version
|
||||
# of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
|
||||
if use_arrow:
|
||||
try:
|
||||
from pyspark.sql.pandas.types import _check_dataframe_localize_timestamps
|
||||
import pyarrow
|
||||
batches = self._collect_as_arrow()
|
||||
if len(batches) > 0:
|
||||
table = pyarrow.Table.from_batches(batches)
|
||||
# Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
|
||||
# values, but we should use datetime.date to match the behavior with when
|
||||
# Arrow optimization is disabled.
|
||||
pdf = table.to_pandas(date_as_object=True)
|
||||
return _check_dataframe_localize_timestamps(pdf, timezone)
|
||||
else:
|
||||
return pd.DataFrame.from_records([], columns=self.columns)
|
||||
except Exception as e:
|
||||
# We might have to allow fallback here as well but multiple Spark jobs can
|
||||
# be executed. So, simply fail in this case for now.
|
||||
msg = (
|
||||
"toPandas attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
|
||||
"reached the error below and can not continue. Note that "
|
||||
"'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
|
||||
"effect on failures in the middle of "
|
||||
"computation.\n %s" % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
raise
|
||||
|
||||
# Below is toPandas without Arrow optimization.
|
||||
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
|
||||
|
||||
dtype = {}
|
||||
for field in self.schema:
|
||||
pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
|
||||
# SPARK-21766: if an integer field is nullable and has null values, it can be
|
||||
# inferred by pandas as float column. Once we convert the column with NaN back
|
||||
# to integer type e.g., np.int16, we will hit exception. So we use the inferred
|
||||
# float type, not the corrected type from the schema in this case.
|
||||
if pandas_type is not None and \
|
||||
not(isinstance(field.dataType, IntegralType) and field.nullable and
|
||||
pdf[field.name].isnull().any()):
|
||||
dtype[field.name] = pandas_type
|
||||
# Ensure we fall back to nullable numpy types, even when whole column is null:
|
||||
if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
|
||||
dtype[field.name] = np.float64
|
||||
if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
|
||||
dtype[field.name] = np.object
|
||||
|
||||
for f, t in dtype.items():
|
||||
pdf[f] = pdf[f].astype(t, copy=False)
|
||||
|
||||
if timezone is None:
|
||||
return pdf
|
||||
else:
|
||||
from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
|
||||
for field in self.schema:
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if isinstance(field.dataType, TimestampType):
|
||||
pdf[field.name] = \
|
||||
_check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
|
||||
return pdf
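
A minimal usage sketch of the two paths described above, assuming an existing `SparkSession` named `spark` and a small DataFrame `df`; the config keys are the ones this method reads:

```python
# Sketch only: toggles the Arrow-based path of toPandas() described above.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
# Optional: fall back to the non-Arrow path instead of failing when the schema
# is unsupported or PyArrow is missing.
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")

pdf = df.toPandas()  # returns a pandas.DataFrame collected to the driver
```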
|
||||
|
||||
@staticmethod
|
||||
def _to_corrected_pandas_type(dt):
|
||||
"""
|
||||
When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type
|
||||
may be wrong. This method gets the corrected data type for Pandas if that type may be
|
||||
inferred incorrectly.
|
||||
"""
|
||||
import numpy as np
|
||||
if type(dt) == ByteType:
|
||||
return np.int8
|
||||
elif type(dt) == ShortType:
|
||||
return np.int16
|
||||
elif type(dt) == IntegerType:
|
||||
return np.int32
|
||||
elif type(dt) == LongType:
|
||||
return np.int64
|
||||
elif type(dt) == FloatType:
|
||||
return np.float32
|
||||
elif type(dt) == DoubleType:
|
||||
return np.float64
|
||||
elif type(dt) == BooleanType:
|
||||
return np.bool
|
||||
elif type(dt) == TimestampType:
|
||||
return np.datetime64
|
||||
else:
|
||||
return None
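
For illustration only (this helper is internal, not a public API), the mapping above plus the nullable-integer caveat from SPARK-21766 can be seen like this:

```python
import numpy as np
import pandas as pd
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.types import LongType, FloatType

# Hypothetical direct calls to the internal helper, for illustration only.
assert PandasConversionMixin._to_corrected_pandas_type(LongType()) == np.int64
assert PandasConversionMixin._to_corrected_pandas_type(FloatType()) == np.float32

# SPARK-21766: a nullable integer column containing nulls is inferred by pandas
# as float64, which is why the corrected integer dtype is skipped in that case.
s = pd.Series([1, None])
assert s.dtype == np.float64
```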
|
||||
|
||||
def _collect_as_arrow(self):
|
||||
"""
|
||||
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
|
||||
and available on driver and worker Python environments.
|
||||
|
||||
.. note:: Experimental.
|
||||
"""
|
||||
from pyspark.sql.dataframe import DataFrame
|
||||
|
||||
assert isinstance(self, DataFrame)
|
||||
|
||||
with SCCallSiteSync(self._sc):
|
||||
port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
|
||||
|
||||
# Collect list of un-ordered batches where last element is a list of correct order indices
|
||||
try:
|
||||
results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
|
||||
finally:
|
||||
# Join serving thread and raise any exceptions from collectAsArrowToPython
|
||||
jsocket_auth_server.getResult()
|
||||
|
||||
# Separate RecordBatches from batch order indices in results
|
||||
batches = results[:-1]
|
||||
batch_order = results[-1]
|
||||
|
||||
# Re-order the batch list using the correct order
|
||||
return [batches[i] for i in batch_order]
|
||||
|
||||
|
||||
class SparkConversionMixin(object):
|
||||
"""
|
||||
Mix-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession`
|
||||
can use this class.
|
||||
"""
|
||||
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
assert isinstance(self, SparkSession)
|
||||
|
||||
from pyspark.sql.pandas.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
if self._wrapped._conf.pandasRespectSessionTimeZone():
|
||||
timezone = self._wrapped._conf.sessionLocalTimeZone()
|
||||
else:
|
||||
timezone = None
|
||||
|
||||
# If no schema supplied by user then get the names of columns only
|
||||
if schema is None:
|
||||
schema = [str(x) if not isinstance(x, basestring) else
|
||||
(x.encode('utf-8') if not isinstance(x, str) else x)
|
||||
for x in data.columns]
|
||||
|
||||
if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
|
||||
try:
|
||||
return self._create_from_pandas_with_arrow(data, schema, timezone)
|
||||
except Exception as e:
|
||||
from pyspark.util import _exception_message
|
||||
|
||||
if self._wrapped._conf.arrowPySparkFallbackEnabled():
|
||||
msg = (
|
||||
"createDataFrame attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
|
||||
"failed by the reason below:\n %s\n"
|
||||
"Attempting non-optimization as "
|
||||
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
|
||||
"true." % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
else:
|
||||
msg = (
|
||||
"createDataFrame attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
|
||||
"reached the error below and will not continue because automatic "
|
||||
"fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
|
||||
"has been set to false.\n %s" % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
raise
|
||||
data = self._convert_from_pandas(data, schema, timezone)
|
||||
return self._create_dataframe(data, schema, samplingRatio, verifySchema)
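
A short usage sketch of this entry point (assuming a running `SparkSession` named `spark`): with Arrow enabled the pandas DataFrame goes through `_create_from_pandas_with_arrow`, otherwise through the plain record-based path.

```python
import pandas as pd

# Sketch only: createDataFrame from a pandas.DataFrame.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

pdf = pd.DataFrame({"id": [1, 2], "v": [1.0, 2.0]})
df = spark.createDataFrame(pdf)                # column names and types taken from pandas
df2 = spark.createDataFrame(pdf, ["id", "v"])  # or with an explicit list of column names
```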
|
||||
|
||||
def _convert_from_pandas(self, pdf, schema, timezone):
|
||||
"""
|
||||
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
|
||||
:return list of records
|
||||
"""
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
assert isinstance(self, SparkSession)
|
||||
|
||||
if timezone is not None:
|
||||
from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
|
||||
copied = False
|
||||
if isinstance(schema, StructType):
|
||||
for field in schema:
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if isinstance(field.dataType, TimestampType):
|
||||
s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
|
||||
if s is not pdf[field.name]:
|
||||
if not copied:
|
||||
# Copy once if the series is modified to prevent the original
|
||||
# Pandas DataFrame from being updated
|
||||
pdf = pdf.copy()
|
||||
copied = True
|
||||
pdf[field.name] = s
|
||||
else:
|
||||
for column, series in pdf.iteritems():
|
||||
s = _check_series_convert_timestamps_tz_local(series, timezone)
|
||||
if s is not series:
|
||||
if not copied:
|
||||
# Copy once if the series is modified to prevent the original
|
||||
# Pandas DataFrame from being updated
|
||||
pdf = pdf.copy()
|
||||
copied = True
|
||||
pdf[column] = s
|
||||
|
||||
# Convert pandas.DataFrame to list of numpy records
|
||||
np_records = pdf.to_records(index=False)
|
||||
|
||||
# Check if any columns need to be fixed for Spark to infer properly
|
||||
if len(np_records) > 0:
|
||||
record_dtype = self._get_numpy_record_dtype(np_records[0])
|
||||
if record_dtype is not None:
|
||||
return [r.astype(record_dtype).tolist() for r in np_records]
|
||||
|
||||
# Convert list of numpy records to python lists
|
||||
return [r.tolist() for r in np_records]
|
||||
|
||||
def _get_numpy_record_dtype(self, rec):
|
||||
"""
|
||||
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
|
||||
the dtypes of fields in a record so they can be properly loaded into Spark.
|
||||
:param rec: a numpy record to check field dtypes
|
||||
:return corrected dtype for a numpy.record or None if no correction needed
|
||||
"""
|
||||
import numpy as np
|
||||
cur_dtypes = rec.dtype
|
||||
col_names = cur_dtypes.names
|
||||
record_type_list = []
|
||||
has_rec_fix = False
|
||||
for i in xrange(len(cur_dtypes)):
|
||||
curr_type = cur_dtypes[i]
|
||||
# If type is a datetime64 timestamp, convert to microseconds
|
||||
# NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
|
||||
# conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
|
||||
if curr_type == np.dtype('datetime64[ns]'):
|
||||
curr_type = 'datetime64[us]'
|
||||
has_rec_fix = True
|
||||
record_type_list.append((str(col_names[i]), curr_type))
|
||||
return np.dtype(record_type_list) if has_rec_fix else None
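
The `datetime64[ns]` correction above matters because, as I understand the numpy behaviour referenced in SPARK-22417, `tolist()` turns nanosecond timestamps into plain integers while microsecond (or coarser) precision yields `datetime` objects:

```python
import numpy as np

ns = np.array(["1970-01-01T00:00:01"], dtype="datetime64[ns]")
us = ns.astype("datetime64[us]")

print(ns.tolist())  # [1000000000] -- plain integers (nanoseconds since epoch)
print(us.tolist())  # [datetime.datetime(1970, 1, 1, 0, 0, 1)]
```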
|
||||
|
||||
def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
|
||||
"""
|
||||
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
|
||||
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
|
||||
data types will be used to coerce the data in Pandas to Arrow conversion.
|
||||
"""
|
||||
from pyspark.sql import SparkSession
|
||||
from pyspark.sql.dataframe import DataFrame
|
||||
|
||||
assert isinstance(self, SparkSession)
|
||||
|
||||
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
|
||||
from pyspark.sql.types import TimestampType
|
||||
from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
|
||||
from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
|
||||
require_minimum_pyarrow_version
|
||||
|
||||
require_minimum_pandas_version()
|
||||
require_minimum_pyarrow_version()
|
||||
|
||||
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
|
||||
import pyarrow as pa
|
||||
|
||||
# Create the Spark schema from list of names passed in with Arrow types
|
||||
if isinstance(schema, (list, tuple)):
|
||||
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
|
||||
struct = StructType()
|
||||
for name, field in zip(schema, arrow_schema):
|
||||
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
|
||||
schema = struct
|
||||
|
||||
# Determine arrow types to coerce data when creating batches
|
||||
if isinstance(schema, StructType):
|
||||
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
|
||||
elif isinstance(schema, DataType):
|
||||
raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
|
||||
else:
|
||||
# Any timestamps must be coerced to be compatible with Spark
|
||||
arrow_types = [to_arrow_type(TimestampType())
|
||||
if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
|
||||
for t in pdf.dtypes]
|
||||
|
||||
# Slice the DataFrame to be batched
|
||||
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
|
||||
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
|
||||
|
||||
# Create list of Arrow (columns, type) for serializer dump_stream
|
||||
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
|
||||
for pdf_slice in pdf_slices]
|
||||
|
||||
jsqlContext = self._wrapped._jsqlContext
|
||||
|
||||
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
|
||||
col_by_name = True # col by name only applies to StructType columns, can't happen here
|
||||
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
|
||||
|
||||
def reader_func(temp_filename):
|
||||
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
|
||||
|
||||
def create_RDD_server():
|
||||
return self._jvm.ArrowRDDServer(jsqlContext)
|
||||
|
||||
# Create Spark DataFrame from Arrow stream file, using one batch per partition
|
||||
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
|
||||
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
|
||||
df = DataFrame(jdf, self._wrapped)
|
||||
df._schema = schema
|
||||
return df
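
The slicing step above uses the `-(-a // b)` idiom for ceiling division, so the pandas DataFrame is split into at most `defaultParallelism` roughly equal slices. A quick illustration of the arithmetic:

```python
# Ceiling division without math.ceil: -(-a // b)
num_rows, parallelism = 10, 3
step = -(-num_rows // parallelism)  # 4 rows per slice
slices = [(start, min(start + step, num_rows)) for start in range(0, num_rows, step)]
print(step, slices)                 # 4 [(0, 4), (4, 8), (8, 10)]
```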
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
from pyspark.sql import SparkSession
|
||||
import pyspark.sql.pandas.conversion
|
||||
globs = pyspark.sql.pandas.conversion.__dict__.copy()
|
||||
spark = SparkSession.builder\
|
||||
.master("local[4]")\
|
||||
.appName("sql.pandas.conversion tests")\
|
||||
.getOrCreate()
|
||||
globs['spark'] = spark
|
||||
(failure_count, test_count) = doctest.testmod(
|
||||
pyspark.sql.pandas.conversion, globs=globs,
|
||||
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
|
||||
spark.stop()
|
||||
if failure_count:
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test()
python/pyspark/sql/pandas/functions.py (new file, 539 lines)
@@ -0,0 +1,539 @@
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import functools
|
||||
import sys
|
||||
|
||||
from pyspark import since
|
||||
from pyspark.rdd import PythonEvalType
|
||||
from pyspark.sql.types import DataType
|
||||
from pyspark.sql.udf import _create_udf
|
||||
|
||||
|
||||
class PandasUDFType(object):
|
||||
"""Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`.
|
||||
"""
|
||||
SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF
|
||||
|
||||
SCALAR_ITER = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
|
||||
|
||||
GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
|
||||
|
||||
COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
|
||||
|
||||
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
|
||||
|
||||
MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
|
||||
|
||||
|
||||
@since(2.3)
|
||||
def pandas_udf(f=None, returnType=None, functionType=None):
|
||||
"""
|
||||
Creates a vectorized user defined function (UDF).
|
||||
|
||||
:param f: user-defined function. A python function if used as a standalone function
|
||||
:param returnType: the return type of the user-defined function. The value can be either a
|
||||
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
|
||||
:param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
|
||||
Default: SCALAR.
|
||||
|
||||
The function type of the UDF can be one of the following:
|
||||
|
||||
1. SCALAR
|
||||
|
||||
A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
|
||||
The length of the returned `pandas.Series` must be the same as that of the input `pandas.Series`.
|
||||
If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
|
||||
|
||||
:class:`MapType`, nested :class:`StructType` are currently not supported as output types.
|
||||
|
||||
Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and
|
||||
:meth:`pyspark.sql.DataFrame.select`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> from pyspark.sql.types import IntegerType, StringType
|
||||
>>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP
|
||||
>>> @pandas_udf(StringType()) # doctest: +SKIP
|
||||
... def to_upper(s):
|
||||
... return s.str.upper()
|
||||
...
|
||||
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
|
||||
... def add_one(x):
|
||||
... return x + 1
|
||||
...
|
||||
>>> df = spark.createDataFrame([(1, "John Doe", 21)],
|
||||
... ("id", "name", "age")) # doctest: +SKIP
|
||||
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
|
||||
... .show() # doctest: +SKIP
|
||||
+----------+--------------+------------+
|
||||
|slen(name)|to_upper(name)|add_one(age)|
|
||||
+----------+--------------+------------+
|
||||
| 8| JOHN DOE| 22|
|
||||
+----------+--------------+------------+
|
||||
>>> @pandas_udf("first string, last string") # doctest: +SKIP
|
||||
... def split_expand(n):
|
||||
... return n.str.split(expand=True)
|
||||
>>> df.select(split_expand("name")).show() # doctest: +SKIP
|
||||
+------------------+
|
||||
|split_expand(name)|
|
||||
+------------------+
|
||||
| [John, Doe]|
|
||||
+------------------+
|
||||
|
||||
.. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
|
||||
column, but is the length of an internal batch used for each call to the function.
|
||||
Therefore, this can be used, for example, to ensure the length of each returned
|
||||
`pandas.Series`, and can not be used as the column length.
|
||||
|
||||
2. SCALAR_ITER
|
||||
|
||||
A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the
|
||||
wrapped Python function takes an iterator of batches as input instead of a single batch and,
|
||||
instead of returning a single output batch, it yields output batches or explicitly returns a
|
||||
generator or an iterator of output batches.
|
||||
It is useful when the UDF execution requires initializing some state, e.g., loading a machine
|
||||
learning model file to apply inference to every input batch.
|
||||
|
||||
.. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all
|
||||
batches from one partition, although it is currently implemented this way.
|
||||
Your code shall not rely on this behavior because it might change in the future for
|
||||
further optimization, e.g., one invocation processes multiple partitions.
|
||||
|
||||
Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
|
||||
:meth:`pyspark.sql.DataFrame.select`.
|
||||
|
||||
>>> import pandas as pd # doctest: +SKIP
|
||||
>>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType
|
||||
>>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP
|
||||
>>> df = spark.createDataFrame(pdf) # doctest: +SKIP
|
||||
|
||||
When the UDF is called with a single column that is not `StructType`, the input to the
|
||||
underlying function is an iterator of `pd.Series`.
|
||||
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def plus_one(batch_iter):
|
||||
... for x in batch_iter:
|
||||
... yield x + 1
|
||||
...
|
||||
>>> df.select(plus_one(col("x"))).show() # doctest: +SKIP
|
||||
+-----------+
|
||||
|plus_one(x)|
|
||||
+-----------+
|
||||
| 2|
|
||||
| 3|
|
||||
| 4|
|
||||
+-----------+
|
||||
|
||||
When the UDF is called with more than one column, the input to the underlying function is an
|
||||
iterator of `pd.Series` tuple.
|
||||
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def multiply_two_cols(batch_iter):
|
||||
... for a, b in batch_iter:
|
||||
... yield a * b
|
||||
...
|
||||
>>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP
|
||||
+-----------------------+
|
||||
|multiply_two_cols(x, x)|
|
||||
+-----------------------+
|
||||
| 1|
|
||||
| 4|
|
||||
| 9|
|
||||
+-----------------------+
|
||||
|
||||
When the UDF is called with a single column that is `StructType`, the input to the underlying
|
||||
function is an iterator of `pd.DataFrame`.
|
||||
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def multiply_two_nested_cols(pdf_iter):
|
||||
... for pdf in pdf_iter:
|
||||
... yield pdf["a"] * pdf["b"]
|
||||
...
|
||||
>>> df.select(
|
||||
... multiply_two_nested_cols(
|
||||
... struct(col("x").alias("a"), col("x").alias("b"))
|
||||
... ).alias("y")
|
||||
... ).show() # doctest: +SKIP
|
||||
+---+
|
||||
| y|
|
||||
+---+
|
||||
| 1|
|
||||
| 4|
|
||||
| 9|
|
||||
+---+
|
||||
|
||||
In the UDF, you can initialize some states before processing batches, wrap your code with
|
||||
`try ... finally ...` or use context managers to ensure the release of resources at the end
|
||||
or in case of early termination.
|
||||
|
||||
>>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP
|
||||
>>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP
|
||||
... def plus_y(batch_iter):
|
||||
... y = y_bc.value # initialize some state
|
||||
... try:
|
||||
... for x in batch_iter:
|
||||
... yield x + y
|
||||
... finally:
|
||||
... pass # release resources here, if any
|
||||
...
|
||||
>>> df.select(plus_y(col("x"))).show() # doctest: +SKIP
|
||||
+---------+
|
||||
|plus_y(x)|
|
||||
+---------+
|
||||
| 2|
|
||||
| 3|
|
||||
| 4|
|
||||
+---------+
|
||||
|
||||
3. GROUPED_MAP
|
||||
|
||||
A grouped map UDF defines a transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
|
||||
The returnType should be a :class:`StructType` describing the schema of the returned
|
||||
`pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match
|
||||
the field names in the defined returnType schema if specified as strings, or match the
|
||||
field data types by position if not strings, e.g. integer indices.
|
||||
The length of the returned `pandas.DataFrame` can be arbitrary.
|
||||
|
||||
Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v")) # doctest: +SKIP
|
||||
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def normalize(pdf):
|
||||
... v = pdf.v
|
||||
... return pdf.assign(v=(v - v.mean()) / v.std())
|
||||
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
|
||||
+---+-------------------+
|
||||
| id| v|
|
||||
+---+-------------------+
|
||||
| 1|-0.7071067811865475|
|
||||
| 1| 0.7071067811865475|
|
||||
| 2|-0.8320502943378437|
|
||||
| 2|-0.2773500981126146|
|
||||
| 2| 1.1094003924504583|
|
||||
+---+-------------------+
|
||||
|
||||
Alternatively, the user can define a function that takes two arguments.
|
||||
In this case, the grouping key(s) will be passed as the first argument and the data will
|
||||
be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
|
||||
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
|
||||
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
|
||||
This is useful when the user does not want to hardcode grouping key(s) in the function.
|
||||
|
||||
>>> import pandas as pd # doctest: +SKIP
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v")) # doctest: +SKIP
|
||||
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def mean_udf(key, pdf):
|
||||
... # key is a tuple of one numpy.int64, which is the value
|
||||
... # of 'id' for the current group
|
||||
... return pd.DataFrame([key + (pdf.v.mean(),)])
|
||||
>>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
|
||||
+---+---+
|
||||
| id| v|
|
||||
+---+---+
|
||||
| 1|1.5|
|
||||
| 2|6.0|
|
||||
+---+---+
|
||||
>>> @pandas_udf(
|
||||
... "id long, `ceil(v / 2)` long, v double",
|
||||
... PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def sum_udf(key, pdf):
|
||||
... # key is a tuple of two numpy.int64s, which is the values
|
||||
... # of 'id' and 'ceil(df.v / 2)' for the current group
|
||||
... return pd.DataFrame([key + (pdf.v.sum(),)])
|
||||
>>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP
|
||||
+---+-----------+----+
|
||||
| id|ceil(v / 2)| v|
|
||||
+---+-----------+----+
|
||||
| 2| 5|10.0|
|
||||
| 1| 1| 3.0|
|
||||
| 2| 3| 5.0|
|
||||
| 2| 2| 3.0|
|
||||
+---+-----------+----+
|
||||
|
||||
.. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
|
||||
recommended to explicitly index the columns by name to ensure the positions are correct,
|
||||
or alternatively use an `OrderedDict`.
|
||||
For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
|
||||
`pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
|
||||
|
||||
4. GROUPED_AGG
|
||||
|
||||
A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
|
||||
The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
|
||||
The returned scalar can be either a python primitive type, e.g., `int` or `float`
|
||||
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
|
||||
|
||||
:class:`MapType` and :class:`StructType` are currently not supported as output types.
|
||||
|
||||
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
|
||||
:class:`pyspark.sql.Window`
|
||||
|
||||
This example shows using grouped aggregated UDFs with groupby:
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v"))
|
||||
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
|
||||
... def mean_udf(v):
|
||||
... return v.mean()
|
||||
>>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP
|
||||
+---+-----------+
|
||||
| id|mean_udf(v)|
|
||||
+---+-----------+
|
||||
| 1| 1.5|
|
||||
| 2| 6.0|
|
||||
+---+-----------+
|
||||
|
||||
This example shows using grouped aggregated UDFs as window functions.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> from pyspark.sql import Window
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v"))
|
||||
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
|
||||
... def mean_udf(v):
|
||||
... return v.mean()
|
||||
>>> w = (Window.partitionBy('id')
|
||||
... .orderBy('v')
|
||||
... .rowsBetween(-1, 0))
|
||||
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
|
||||
+---+----+------+
|
||||
| id| v|mean_v|
|
||||
+---+----+------+
|
||||
| 1| 1.0| 1.0|
|
||||
| 1| 2.0| 1.5|
|
||||
| 2| 3.0| 3.0|
|
||||
| 2| 5.0| 4.0|
|
||||
| 2|10.0| 7.5|
|
||||
+---+----+------+
|
||||
|
||||
.. note:: For performance reasons, the input series to window functions are not copied.
|
||||
Therefore, mutating the input series is not allowed and will cause incorrect results.
|
||||
For the same reason, users should also not rely on the index of the input series.
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
|
||||
|
||||
5. MAP_ITER
|
||||
|
||||
A map iterator Pandas UDF is used to transform data with an iterator of batches.
|
||||
It can be used with :meth:`pyspark.sql.DataFrame.mapInPandas`.
|
||||
|
||||
It can return the output of arbitrary length in contrast to the scalar Pandas UDF.
|
||||
It maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
|
||||
function and returns the result as a :class:`DataFrame`.
|
||||
|
||||
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return another
|
||||
iterator of `pandas.DataFrame`\\s. All columns are passed together as an
|
||||
iterator of `pandas.DataFrame`\\s to the user-defined function and the returned iterator of
|
||||
`pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
|
||||
|
||||
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
|
||||
... ("id", "age")) # doctest: +SKIP
|
||||
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
|
||||
... def filter_func(batch_iter):
|
||||
... for pdf in batch_iter:
|
||||
... yield pdf[pdf.id == 1]
|
||||
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
|
||||
+---+---+
|
||||
| id|age|
|
||||
+---+---+
|
||||
| 1| 21|
|
||||
+---+---+
|
||||
|
||||
6. COGROUPED_MAP
|
||||
|
||||
A cogrouped map UDF defines a transformation: (`pandas.DataFrame`, `pandas.DataFrame`) ->
|
||||
`pandas.DataFrame`. The `returnType` should be a :class:`StructType` describing the schema
|
||||
of the returned `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame`
|
||||
must either match the field names in the defined `returnType` schema if specified as strings,
|
||||
or match the field data types by position if not strings, e.g. integer indices. The length
|
||||
of the returned `pandas.DataFrame` can be arbitrary.
|
||||
|
||||
CoGrouped map UDFs are used with :meth:`pyspark.sql.CoGroupedData.apply`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df1 = spark.createDataFrame(
|
||||
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
|
||||
... ("time", "id", "v1"))
|
||||
>>> df2 = spark.createDataFrame(
|
||||
... [(20000101, 1, "x"), (20000101, 2, "y")],
|
||||
... ("time", "id", "v2"))
|
||||
>>> @pandas_udf("time int, id int, v1 double, v2 string",
|
||||
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
|
||||
... def asof_join(l, r):
|
||||
... return pd.merge_asof(l, r, on="time", by="id")
|
||||
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
|
||||
+---------+---+---+---+
|
||||
| time| id| v1| v2|
|
||||
+---------+---+---+---+
|
||||
| 20000101| 1|1.0| x|
|
||||
| 20000102| 1|3.0| x|
|
||||
| 20000101| 2|2.0| y|
|
||||
| 20000102| 2|4.0| y|
|
||||
+---------+---+---+---+
|
||||
|
||||
Alternatively, the user can define a function that takes three arguments. In this case,
|
||||
the grouping key(s) will be passed as the first argument and the data will be passed as the
|
||||
second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
|
||||
types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
|
||||
`pandas.DataFrame` containing all columns from the original Spark DataFrames.
|
||||
>>> @pandas_udf("time int, id int, v1 double, v2 string",
|
||||
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
|
||||
... def asof_join(k, l, r):
|
||||
... if k == (1,):
|
||||
... return pd.merge_asof(l, r, on="time", by="id")
|
||||
... else:
|
||||
... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
|
||||
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
|
||||
+---------+---+---+---+
|
||||
| time| id| v1| v2|
|
||||
+---------+---+---+---+
|
||||
| 20000101| 1|1.0| x|
|
||||
| 20000102| 1|3.0| x|
|
||||
+---------+---+---+---+
|
||||
|
||||
.. note:: The user-defined functions are considered deterministic by default. Due to
|
||||
optimization, duplicate invocations may be eliminated or the function may even be invoked
|
||||
more times than it is present in the query. If your function is not deterministic, call
|
||||
`asNondeterministic` on the user defined function. E.g.:
|
||||
|
||||
>>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP
|
||||
... def random(v):
|
||||
... import numpy as np
|
||||
... import pandas as pd
|
||||
... return pd.Series(np.random.randn(len(v)))
|
||||
>>> random = random.asNondeterministic() # doctest: +SKIP
|
||||
|
||||
.. note:: The user-defined functions do not support conditional expressions or short circuiting
|
||||
in boolean expressions and they end up being executed internally for all rows. If the functions
|
||||
can fail on special rows, the workaround is to incorporate the condition into the functions.
|
||||
|
||||
.. note:: The user-defined functions do not take keyword arguments on the calling side.
|
||||
|
||||
.. note:: The data type of returned `pandas.Series` from the user-defined functions should be
|
||||
matched with defined returnType (see :meth:`types.to_arrow_type` and
|
||||
:meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do
|
||||
conversion on returned data. The conversion is not guaranteed to be correct and results
|
||||
should be checked for accuracy by users.
|
||||
"""
|
||||
|
||||
# The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that
|
||||
# are not yet visible to the user. Some of behaviors are buggy and might be changed in the near
|
||||
# future. The table might have to be eventually documented externally.
|
||||
# Please see SPARK-28132's PR to see the codes in order to generate the table below.
|
||||
#
|
||||
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
|
||||
# |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa
|
||||
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
|
||||
# | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| 0| X| # noqa
|
||||
# | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
|
||||
# | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa
|
||||
# | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| X| X| # noqa
|
||||
# | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| X| X| # noqa
|
||||
# | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| X| X| # noqa
|
||||
# | string| None| ''| ''| ''| '\x01'| '\x01'| ''| ''| '\x01'| '\x01'| ''| ''| ''| X| X| 'a'| X| X| ''| X| ''| X| X| # noqa
|
||||
# | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa
|
||||
# | array<int>| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa
|
||||
# | map<string,int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa
|
||||
# | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')| X| bytearray(b'')| # noqa
|
||||
# +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa
|
||||
#
|
||||
# Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be
|
||||
# used in `returnType`.
|
||||
# Note: The values inside of the table are generated by `repr`.
|
||||
# Note: Python 3.7.3, Pandas 0.24.2 and PyArrow 0.13.0 are used.
|
||||
# Note: Timezone is KST.
|
||||
# Note: 'X' means it throws an exception during the conversion.
|
||||
|
||||
# decorator @pandas_udf(returnType, functionType)
|
||||
is_decorator = f is None or isinstance(f, (str, DataType))
|
||||
|
||||
if is_decorator:
|
||||
# If DataType has been passed as a positional argument
|
||||
# for decorator use it as a returnType
|
||||
return_type = f or returnType
|
||||
|
||||
if functionType is not None:
|
||||
# @pandas_udf(dataType, functionType=functionType)
|
||||
# @pandas_udf(returnType=dataType, functionType=functionType)
|
||||
eval_type = functionType
|
||||
elif returnType is not None and isinstance(returnType, int):
|
||||
# @pandas_udf(dataType, functionType)
|
||||
eval_type = returnType
|
||||
else:
|
||||
# @pandas_udf(dataType) or @pandas_udf(returnType=dataType)
|
||||
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
|
||||
else:
|
||||
return_type = returnType
|
||||
|
||||
if functionType is not None:
|
||||
eval_type = functionType
|
||||
else:
|
||||
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
|
||||
|
||||
if return_type is None:
|
||||
raise ValueError("Invalid returnType: returnType can not be None")
|
||||
|
||||
if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
||||
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
|
||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
|
||||
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
|
||||
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
|
||||
raise ValueError("Invalid functionType: "
|
||||
"functionType must be one the values from PandasUDFType")
|
||||
|
||||
if is_decorator:
|
||||
return functools.partial(_create_udf, returnType=return_type, evalType=eval_type)
|
||||
else:
|
||||
return _create_udf(f=f, returnType=return_type, evalType=eval_type)
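
Per the branching above, the same UDF can be declared either by calling `pandas_udf` directly or by using it as a decorator; a minimal sketch (SCALAR is the default function type):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Direct call: f and returnType (functionType defaults to SCALAR)
slen = pandas_udf(lambda s: s.str.len(), "integer")

# Decorator: returnType (and optionally functionType) as decorator arguments
@pandas_udf("integer", PandasUDFType.SCALAR)
def slen_decorated(s):
    return s.str.len()
```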
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
from pyspark.sql import SparkSession
|
||||
import pyspark.sql.pandas.functions
|
||||
globs = pyspark.sql.pandas.functions.__dict__.copy()
|
||||
spark = SparkSession.builder\
|
||||
.master("local[4]")\
|
||||
.appName("sql.pandas.functions tests")\
|
||||
.getOrCreate()
|
||||
globs['spark'] = spark
|
||||
(failure_count, test_count) = doctest.testmod(
|
||||
pyspark.sql.pandas.functions, globs=globs,
|
||||
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
|
||||
spark.stop()
|
||||
if failure_count:
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test()
python/pyspark/sql/cogroup.py → python/pyspark/sql/pandas/group_ops.py (renamed)
@@ -22,7 +22,83 @@ from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
|
||||
|
||||
|
||||
class CoGroupedData(object):
|
||||
class PandasGroupedOpsMixin(object):
|
||||
"""
|
||||
Mix-in for pandas grouped operations. Currently, only :class:`GroupedData`
|
||||
can use this class.
|
||||
"""
|
||||
|
||||
def apply(self, udf):
|
||||
"""
|
||||
Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
|
||||
as a `DataFrame`.
|
||||
|
||||
The user-defined function should take a `pandas.DataFrame` and return another
|
||||
`pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
|
||||
to the user-function and the returned `pandas.DataFrame` are combined as a
|
||||
:class:`DataFrame`.
|
||||
|
||||
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
|
||||
returnType of the pandas udf.
|
||||
|
||||
.. note:: This function requires a full shuffle. All the data of a group will be loaded
|
||||
into memory, so the user should be aware of the potential OOM risk if data is skewed
|
||||
and certain groups are too large to fit in memory.
|
||||
|
||||
:param udf: a grouped map user-defined function returned by
|
||||
:func:`pyspark.sql.functions.pandas_udf`.
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v"))
|
||||
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
|
||||
... def normalize(pdf):
|
||||
... v = pdf.v
|
||||
... return pdf.assign(v=(v - v.mean()) / v.std())
|
||||
>>> df.groupby("id").apply(normalize).show() # doctest: +SKIP
|
||||
+---+-------------------+
|
||||
| id| v|
|
||||
+---+-------------------+
|
||||
| 1|-0.7071067811865475|
|
||||
| 1| 0.7071067811865475|
|
||||
| 2|-0.8320502943378437|
|
||||
| 2|-0.2773500981126146|
|
||||
| 2| 1.1094003924504583|
|
||||
+---+-------------------+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
|
||||
|
||||
"""
|
||||
from pyspark.sql import GroupedData
|
||||
|
||||
assert isinstance(self, GroupedData)
|
||||
|
||||
# Columns are special because hasattr always returns True
|
||||
if isinstance(udf, Column) or not hasattr(udf, 'func') \
|
||||
or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
|
||||
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
|
||||
"GROUPED_MAP.")
|
||||
df = self._df
|
||||
udf_column = udf(*[df[col] for col in df.columns])
|
||||
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
@since(3.0)
|
||||
def cogroup(self, other):
|
||||
"""
|
||||
Cogroups this group with another group so that we can run cogrouped operations.
|
||||
|
||||
See :class:`CoGroupedData` for the operations that can be run.
|
||||
"""
|
||||
from pyspark.sql import GroupedData
|
||||
|
||||
assert isinstance(self, GroupedData)
|
||||
|
||||
return PandasCogroupedOps(self, other)
|
||||
|
||||
|
||||
class PandasCogroupedOps(object):
|
||||
"""
|
||||
A logical grouping of two :class:`GroupedData`,
|
||||
created by :func:`GroupedData.cogroup`.
|
||||
|
@ -124,15 +200,15 @@ class CoGroupedData(object):
|
|||
def _test():
|
||||
import doctest
|
||||
from pyspark.sql import SparkSession
|
||||
import pyspark.sql.cogroup
|
||||
globs = pyspark.sql.cogroup.__dict__.copy()
|
||||
import pyspark.sql.pandas.group_ops
|
||||
globs = pyspark.sql.pandas.group_ops.__dict__.copy()
|
||||
spark = SparkSession.builder\
|
||||
.master("local[4]")\
|
||||
.appName("sql.cogroup tests")\
|
||||
.appName("sql.pandas.group tests")\
|
||||
.getOrCreate()
|
||||
globs['spark'] = spark
|
||||
(failure_count, test_count) = doctest.testmod(
|
||||
pyspark.sql.cogroup, globs=globs,
|
||||
pyspark.sql.pandas.group_ops, globs=globs,
|
||||
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
|
||||
spark.stop()
|
||||
if failure_count:
|
python/pyspark/sql/pandas/map_ops.py (new file, 96 lines)
@@ -0,0 +1,96 @@
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import sys
|
||||
|
||||
from pyspark import since
|
||||
from pyspark.rdd import PythonEvalType
|
||||
|
||||
|
||||
class PandasMapOpsMixin(object):
|
||||
"""
|
||||
Mix-in for pandas map operations. Currently, only :class:`DataFrame`
|
||||
can use this class.
|
||||
"""
|
||||
|
||||
@since(3.0)
|
||||
def mapInPandas(self, udf):
|
||||
"""
|
||||
Maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined
|
||||
function and returns the result as a :class:`DataFrame`.
|
||||
|
||||
The user-defined function should take an iterator of `pandas.DataFrame`\\s and return
|
||||
another iterator of `pandas.DataFrame`\\s. All columns are passed
|
||||
together as an iterator of `pandas.DataFrame`\\s to the user-defined function and the
|
||||
returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
|
||||
Each `pandas.DataFrame` size can be controlled by
|
||||
`spark.sql.execution.arrow.maxRecordsPerBatch`.
|
||||
Its schema must match the returnType of the Pandas user-defined function.
|
||||
|
||||
:param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame([(1, 21), (2, 30)],
|
||||
... ("id", "age")) # doctest: +SKIP
|
||||
>>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP
|
||||
... def filter_func(batch_iter):
|
||||
... for pdf in batch_iter:
|
||||
... yield pdf[pdf.id == 1]
|
||||
>>> df.mapInPandas(filter_func).show() # doctest: +SKIP
|
||||
+---+---+
|
||||
| id|age|
|
||||
+---+---+
|
||||
| 1| 21|
|
||||
+---+---+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
|
||||
|
||||
"""
|
||||
from pyspark.sql import Column, DataFrame
|
||||
|
||||
assert isinstance(self, DataFrame)
|
||||
|
||||
# Columns are special because hasattr always returns True
|
||||
if isinstance(udf, Column) or not hasattr(udf, 'func') \
|
||||
or udf.evalType != PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
|
||||
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
|
||||
"MAP_ITER.")
|
||||
|
||||
udf_column = udf(*[self[col] for col in self.columns])
|
||||
jdf = self._jdf.mapInPandas(udf_column._jc.expr())
|
||||
return DataFrame(jdf, self.sql_ctx)
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
from pyspark.sql import SparkSession
|
||||
import pyspark.sql.pandas.map_ops
|
||||
globs = pyspark.sql.pandas.map_ops.__dict__.copy()
|
||||
spark = SparkSession.builder\
|
||||
.master("local[4]")\
|
||||
.appName("sql.pandas.map_ops tests")\
|
||||
.getOrCreate()
|
||||
globs['spark'] = spark
|
||||
(failure_count, test_count) = doctest.testmod(
|
||||
pyspark.sql.pandas.map_ops, globs=globs,
|
||||
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
|
||||
spark.stop()
|
||||
if failure_count:
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test()
python/pyspark/sql/pandas/serializers.py (new file, 281 lines)
@@ -0,0 +1,281 @@
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
|
||||
"""
|
||||
|
||||
import sys
|
||||
if sys.version < '3':
|
||||
from itertools import izip as zip
|
||||
else:
|
||||
basestring = unicode = str
|
||||
xrange = range
|
||||
|
||||
from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer
|
||||
|
||||
|
||||
class SpecialLengths(object):
|
||||
END_OF_DATA_SECTION = -1
|
||||
PYTHON_EXCEPTION_THROWN = -2
|
||||
TIMING_DATA = -3
|
||||
END_OF_STREAM = -4
|
||||
NULL = -5
|
||||
START_ARROW_STREAM = -6
|
||||
|
||||
|
||||
class ArrowCollectSerializer(Serializer):
    """
    Deserialize a stream of batches followed by batch order information. Used in
    PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython()
    in the JVM.
    """

    def __init__(self):
        self.serializer = ArrowStreamSerializer()

    def dump_stream(self, iterator, stream):
        return self.serializer.dump_stream(iterator, stream)

    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
        a list of indices that can be used to put the RecordBatches in the correct order.
        """
        # load the batches
        for batch in self.serializer.load_stream(stream):
            yield batch

        # load the batch order indices or propagate any error that occurred in the JVM
        num = read_int(stream)
        if num == -1:
            error_msg = UTF8Deserializer().loads(stream)
            raise RuntimeError("An error occurred while calling "
                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
        batch_order = []
        for i in xrange(num):
            index = read_int(stream)
            batch_order.append(index)
        yield batch_order

    def __repr__(self):
        return "ArrowCollectSerializer(%s)" % self.serializer


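Aside (not part of the diff): a hypothetical consumer of `ArrowCollectSerializer.load_stream` showing how the trailing `batch_order` list can be used to restore the JVM-assigned ordering; `load_stream_iterator` is assumed to be the generator returned above.

```python
def collect_in_order(load_stream_iterator):
    # The serializer yields the record batches in arrival order and, last, a list of indices.
    results = list(load_stream_iterator)
    batches, batch_order = results[:-1], results[-1]
    # Re-arrange the batches into the order assigned on the JVM side.
    return [batches[i] for i in batch_order]
```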
class ArrowStreamSerializer(Serializer):
    """
    Serializes Arrow record batches as a stream.
    """

    def dump_stream(self, iterator, stream):
        import pyarrow as pa
        writer = None
        try:
            for batch in iterator:
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
        finally:
            if writer is not None:
                writer.close()

    def load_stream(self, stream):
        import pyarrow as pa
        reader = pa.ipc.open_stream(stream)
        for batch in reader:
            yield batch

    def __repr__(self):
        return "ArrowStreamSerializer"


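Aside (not part of the diff): a self-contained sketch of the same write/read cycle against an in-memory buffer instead of the worker's socket file object.

```python
import io
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
                                   names=["id", "value"])
buf = io.BytesIO()
writer = pa.RecordBatchStreamWriter(buf, batch.schema)  # same writer as in dump_stream()
writer.write_batch(batch)
writer.close()

buf.seek(0)
for read_batch in pa.ipc.open_stream(buf):  # same reader as in load_stream()
    assert read_batch.equals(batch)
```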
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
    """
    Serializes Pandas.Series as Arrow data with Arrow streaming format.

    :param timezone: A timezone to respect when handling timestamp values
    :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
    :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name):
        super(ArrowStreamPandasSerializer, self).__init__()
        self._timezone = timezone
        self._safecheck = safecheck
        self._assign_cols_by_name = assign_cols_by_name

    def arrow_to_pandas(self, arrow_column):
        from pyspark.sql.pandas.types import _check_series_localize_timestamps

        # If the given column is a date type column, creates a series of datetime.date directly
        # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
        # datetime64[ns] type handling.
        s = arrow_column.to_pandas(date_as_object=True)

        s = _check_series_localize_timestamps(s, self._timezone)
        return s

    def _create_batch(self, series):
        """
        Create an Arrow record batch from the given pandas.Series or list of Series,
        with optional type.

        :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
        :return: Arrow RecordBatch
        """
        import pandas as pd
        import pyarrow as pa
        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
        # Make input conform to [(series1, type1), (series2, type2), ...]
        if not isinstance(series, (list, tuple)) or \
                (len(series) == 2 and isinstance(series[1], pa.DataType)):
            series = [series]
        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

        def create_array(s, t):
            mask = s.isnull()
            # Ensure timestamp series are in expected form for Spark internal representation
            if t is not None and pa.types.is_timestamp(t):
                s = _check_series_convert_timestamps_internal(s, self._timezone)
            try:
                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
            except pa.ArrowException as e:
                error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
                            "Array (%s). It can be caused by overflows or other unsafe " + \
                            "conversions warned by Arrow. Arrow safe type check can be " + \
                            "disabled by using SQL config " + \
                            "`spark.sql.execution.pandas.arrowSafeTypeConversion`."
                raise RuntimeError(error_msg % (s.dtype, t), e)
            return array

        arrs = []
        for s, t in series:
            if t is not None and pa.types.is_struct(t):
                if not isinstance(s, pd.DataFrame):
                    raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                     "but got: %s" % str(type(s)))

                # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
                if len(s) == 0 and len(s.columns) == 0:
                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
                # Assign result columns by schema name if user labeled with strings
                elif self._assign_cols_by_name and any(isinstance(name, basestring)
                                                       for name in s.columns):
                    arrs_names = [(create_array(s[field.name], field.type), field.name)
                                  for field in t]
                # Assign result columns by position
                else:
                    arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
                                  for i, field in enumerate(t)]

                struct_arrs, struct_names = zip(*arrs_names)
                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
            else:
                arrs.append(create_array(s, t))

        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])

    def dump_stream(self, iterator, stream):
        """
        Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
        a list of series accompanied by an optional pyarrow type to coerce the data to.
        """
        batches = (self._create_batch(series) for series in iterator)
        super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)

    def load_stream(self, stream):
        """
        Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
        """
        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
        import pyarrow as pa
        for batch in batches:
            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

    def __repr__(self):
        return "ArrowStreamPandasSerializer"


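Aside (not part of the diff): what the `safe=self._safecheck` flag in `create_array` changes, shown in isolation; the SQL config `spark.sql.execution.pandas.arrowSafeTypeConversion` toggles this behaviour.

```python
import pandas as pd
import pyarrow as pa

s = pd.Series([1.5, 2.5])
try:
    pa.Array.from_pandas(s, type=pa.int32(), safe=True)   # rejects the lossy float -> int cast
except pa.ArrowException as e:
    print("unsafe conversion rejected:", e)

arr = pa.Array.from_pandas(s, type=pa.int32(), safe=False)  # silently truncates to [1, 2]
```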
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
    """
    Serializer used by Python worker to evaluate Pandas UDFs
    """

    def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
        super(ArrowStreamPandasUDFSerializer, self) \
            .__init__(timezone, safecheck, assign_cols_by_name)
        self._df_for_struct = df_for_struct

    def arrow_to_pandas(self, arrow_column):
        import pyarrow.types as types

        if self._df_for_struct and types.is_struct(arrow_column.type):
            import pandas as pd
            series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
                      .rename(field.name)
                      for column, field in zip(arrow_column.flatten(), arrow_column.type)]
            s = pd.concat(series, axis=1)
        else:
            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
        return s

    def dump_stream(self, iterator, stream):
        """
        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
        This should be sent after creating the first record batch so in case of an error, it can
        be sent back to the JVM before the Arrow stream starts.
        """

        def init_stream_yield_batches():
            should_write_start_length = True
            for series in iterator:
                batch = self._create_batch(series)
                if should_write_start_length:
                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
                    should_write_start_length = False
                yield batch

        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

    def __repr__(self):
        return "ArrowStreamPandasUDFSerializer"


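Aside (not part of the diff): how the `df_for_struct` branch above turns one Arrow struct column into a pandas DataFrame, sketched here with a plain `StructArray` rather than a column coming off the stream.

```python
import pandas as pd
import pyarrow as pa

struct_col = pa.array([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}])  # struct<a: int64, b: string>
series = [child.to_pandas().rename(field.name)
          for child, field in zip(struct_col.flatten(), struct_col.type)]
pdf = pd.concat(series, axis=1)  # columns "a" and "b", one row per struct value
print(pdf)
```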
class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):

    def load_stream(self, stream):
        """
        Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
        lists of pandas.Series.
        """
        import pyarrow as pa
        dataframes_in_group = None

        while dataframes_in_group is None or dataframes_in_group > 0:
            dataframes_in_group = read_int(stream)

            if dataframes_in_group == 2:
                batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
                yield (
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
                )

            elif dataframes_in_group != 0:
                raise ValueError(
                    'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))
python/pyspark/sql/pandas/types.py  (new file, 284 added lines)
@@ -0,0 +1,284 @@
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Type-specific codes between pandas and PyArrow. Also contains some utils to correct
|
||||
pandas instances during the type conversion.
|
||||
"""
|
||||
|
||||
from pyspark.sql.types import *
|
||||
|
||||
|
||||
def to_arrow_type(dt):
|
||||
""" Convert Spark data type to pyarrow type
|
||||
"""
|
||||
import pyarrow as pa
|
||||
if type(dt) == BooleanType:
|
||||
arrow_type = pa.bool_()
|
||||
elif type(dt) == ByteType:
|
||||
arrow_type = pa.int8()
|
||||
elif type(dt) == ShortType:
|
||||
arrow_type = pa.int16()
|
||||
elif type(dt) == IntegerType:
|
||||
arrow_type = pa.int32()
|
||||
elif type(dt) == LongType:
|
||||
arrow_type = pa.int64()
|
||||
elif type(dt) == FloatType:
|
||||
arrow_type = pa.float32()
|
||||
elif type(dt) == DoubleType:
|
||||
arrow_type = pa.float64()
|
||||
elif type(dt) == DecimalType:
|
||||
arrow_type = pa.decimal128(dt.precision, dt.scale)
|
||||
elif type(dt) == StringType:
|
||||
arrow_type = pa.string()
|
||||
elif type(dt) == BinaryType:
|
||||
arrow_type = pa.binary()
|
||||
elif type(dt) == DateType:
|
||||
arrow_type = pa.date32()
|
||||
elif type(dt) == TimestampType:
|
||||
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
|
||||
arrow_type = pa.timestamp('us', tz='UTC')
|
||||
elif type(dt) == ArrayType:
|
||||
if type(dt.elementType) in [StructType, TimestampType]:
|
||||
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
|
||||
arrow_type = pa.list_(to_arrow_type(dt.elementType))
|
||||
elif type(dt) == StructType:
|
||||
if any(type(field.dataType) == StructType for field in dt):
|
||||
raise TypeError("Nested StructType not supported in conversion to Arrow")
|
||||
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
|
||||
for field in dt]
|
||||
arrow_type = pa.struct(fields)
|
||||
else:
|
||||
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
|
||||
return arrow_type
|
||||
|
||||
|
||||
def to_arrow_schema(schema):
|
||||
""" Convert a schema from Spark to Arrow
|
||||
"""
|
||||
import pyarrow as pa
|
||||
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
|
||||
for field in schema]
|
||||
return pa.schema(fields)
|
||||
|
||||
|
||||
def from_arrow_type(at):
|
||||
""" Convert pyarrow type to Spark data type.
|
||||
"""
|
||||
import pyarrow.types as types
|
||||
if types.is_boolean(at):
|
||||
spark_type = BooleanType()
|
||||
elif types.is_int8(at):
|
||||
spark_type = ByteType()
|
||||
elif types.is_int16(at):
|
||||
spark_type = ShortType()
|
||||
elif types.is_int32(at):
|
||||
spark_type = IntegerType()
|
||||
elif types.is_int64(at):
|
||||
spark_type = LongType()
|
||||
elif types.is_float32(at):
|
||||
spark_type = FloatType()
|
||||
elif types.is_float64(at):
|
||||
spark_type = DoubleType()
|
||||
elif types.is_decimal(at):
|
||||
spark_type = DecimalType(precision=at.precision, scale=at.scale)
|
||||
elif types.is_string(at):
|
||||
spark_type = StringType()
|
||||
elif types.is_binary(at):
|
||||
spark_type = BinaryType()
|
||||
elif types.is_date32(at):
|
||||
spark_type = DateType()
|
||||
elif types.is_timestamp(at):
|
||||
spark_type = TimestampType()
|
||||
elif types.is_list(at):
|
||||
if types.is_timestamp(at.value_type):
|
||||
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
|
||||
spark_type = ArrayType(from_arrow_type(at.value_type))
|
||||
elif types.is_struct(at):
|
||||
if any(types.is_struct(field.type) for field in at):
|
||||
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
|
||||
return StructType(
|
||||
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
|
||||
for field in at])
|
||||
else:
|
||||
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
|
||||
return spark_type
|
||||
|
||||
|
||||
def from_arrow_schema(arrow_schema):
|
||||
""" Convert schema from Arrow to Spark.
|
||||
"""
|
||||
return StructType(
|
||||
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
|
||||
for field in arrow_schema])
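Aside (not part of the diff): the two schema helpers defined above round-trip for supported types.

```python
from pyspark.sql.types import StructType, StructField, LongType, StringType
from pyspark.sql.pandas.types import to_arrow_schema, from_arrow_schema

spark_schema = StructType([StructField("id", LongType()), StructField("name", StringType())])
arrow_schema = to_arrow_schema(spark_schema)          # pyarrow schema: id: int64, name: string
assert from_arrow_schema(arrow_schema) == spark_schema
```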
|
||||
|
||||
|
||||
def _get_local_timezone():
|
||||
""" Get local timezone using pytz with environment variable, or dateutil.
|
||||
|
||||
If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone
|
||||
string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and
|
||||
it reads system configuration to know the system local timezone.
|
||||
|
||||
See also:
|
||||
- https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753
|
||||
- https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338
|
||||
"""
|
||||
import os
|
||||
return os.environ.get('TZ', 'dateutil/:')
|
||||
|
||||
|
||||
def _check_series_localize_timestamps(s, timezone):
|
||||
"""
|
||||
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
|
||||
|
||||
If the input series is not a timestamp series, then the same series is returned. If the input
|
||||
series is a timestamp series, then a converted series is returned.
|
||||
|
||||
:param s: pandas.Series
|
||||
:param timezone: the timezone to convert. if None then use local timezone
|
||||
:return pandas.Series that have been converted to tz-naive
|
||||
"""
|
||||
from pyspark.sql.pandas.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
from pandas.api.types import is_datetime64tz_dtype
|
||||
tz = timezone or _get_local_timezone()
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64tz_dtype(s.dtype):
|
||||
return s.dt.tz_convert(tz).dt.tz_localize(None)
|
||||
else:
|
||||
return s
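Aside (not part of the diff): the pandas calls used by the helper above, shown on a tiny tz-aware series.

```python
import pandas as pd

s = pd.Series(pd.to_datetime(["2015-11-01 01:30:00"])).dt.tz_localize("UTC")
# Convert to the target zone, then drop the timezone info, exactly as the helper does.
naive = s.dt.tz_convert("America/New_York").dt.tz_localize(None)
print(naive[0])  # 2015-10-31 21:30:00
```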
|
||||
|
||||
|
||||
def _check_dataframe_localize_timestamps(pdf, timezone):
|
||||
"""
|
||||
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param pdf: pandas.DataFrame
|
||||
:param timezone: the timezone to convert. if None then use local timezone
|
||||
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
|
||||
"""
|
||||
from pyspark.sql.pandas.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
for column, series in pdf.iteritems():
|
||||
pdf[column] = _check_series_localize_timestamps(series, timezone)
|
||||
return pdf
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_internal(s, timezone):
|
||||
"""
|
||||
Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
|
||||
Spark internal storage
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param timezone: the timezone to convert. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
|
||||
"""
|
||||
from pyspark.sql.pandas.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64_dtype(s.dtype):
|
||||
# When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive
|
||||
# timestamp is during the hour when the clock is adjusted backward during due to
|
||||
# daylight saving time (dst).
|
||||
# E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to
|
||||
# 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize
|
||||
# a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either
|
||||
# dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500).
|
||||
#
|
||||
# Here we explicit choose to use standard time. This matches the default behavior of
|
||||
# pytz.
|
||||
#
|
||||
# Here are some code to help understand this behavior:
|
||||
# >>> import datetime
|
||||
# >>> import pandas as pd
|
||||
# >>> import pytz
|
||||
# >>>
|
||||
# >>> t = datetime.datetime(2015, 11, 1, 1, 30)
|
||||
# >>> ts = pd.Series([t])
|
||||
# >>> tz = pytz.timezone('America/New_York')
|
||||
# >>>
|
||||
# >>> ts.dt.tz_localize(tz, ambiguous=True)
|
||||
# 0 2015-11-01 01:30:00-04:00
|
||||
# dtype: datetime64[ns, America/New_York]
|
||||
# >>>
|
||||
# >>> ts.dt.tz_localize(tz, ambiguous=False)
|
||||
# 0 2015-11-01 01:30:00-05:00
|
||||
# dtype: datetime64[ns, America/New_York]
|
||||
# >>>
|
||||
# >>> str(tz.localize(t))
|
||||
# '2015-11-01 01:30:00-05:00'
|
||||
tz = timezone or _get_local_timezone()
|
||||
return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
|
||||
elif is_datetime64tz_dtype(s.dtype):
|
||||
return s.dt.tz_convert('UTC')
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
|
||||
"""
|
||||
Convert timestamp to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param from_timezone: the timezone to convert from. if None then use local timezone
|
||||
:param to_timezone: the timezone to convert to. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
|
||||
"""
|
||||
from pyspark.sql.pandas.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
|
||||
from_tz = from_timezone or _get_local_timezone()
|
||||
to_tz = to_timezone or _get_local_timezone()
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64tz_dtype(s.dtype):
|
||||
return s.dt.tz_convert(to_tz).dt.tz_localize(None)
|
||||
elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
|
||||
# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
|
||||
return s.apply(
|
||||
lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
|
||||
if ts is not pd.NaT else pd.NaT)
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_local_tz(s, timezone):
|
||||
"""
|
||||
Convert timestamp to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param timezone: the timezone to convert to. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
|
||||
"""
|
||||
return _check_series_convert_timestamps_localize(s, None, timezone)
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_tz_local(s, timezone):
|
||||
"""
|
||||
Convert timestamp to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param timezone: the timezone to convert from. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
|
||||
"""
|
||||
return _check_series_convert_timestamps_localize(s, timezone, None)
python/pyspark/sql/pandas/utils.py  (new file, 60 added lines)
@@ -0,0 +1,60 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


def require_minimum_pandas_version():
    """ Raise ImportError if minimum version of Pandas is not installed
    """
    # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
    minimum_pandas_version = "0.23.2"

    from distutils.version import LooseVersion
    try:
        import pandas
        have_pandas = True
    except ImportError:
        have_pandas = False
    if not have_pandas:
        raise ImportError("Pandas >= %s must be installed; however, "
                          "it was not found." % minimum_pandas_version)
    if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
        raise ImportError("Pandas >= %s must be installed; however, "
                          "your version was %s." % (minimum_pandas_version, pandas.__version__))


def require_minimum_pyarrow_version():
    """ Raise ImportError if minimum version of pyarrow is not installed
    """
    # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
    minimum_pyarrow_version = "0.15.1"

    from distutils.version import LooseVersion
    import os
    try:
        import pyarrow
        have_arrow = True
    except ImportError:
        have_arrow = False
    if not have_arrow:
        raise ImportError("PyArrow >= %s must be installed; however, "
                          "it was not found." % minimum_pyarrow_version)
    if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
        raise ImportError("PyArrow >= %s must be installed; however, "
                          "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
    if os.environ.get("ARROW_PRE_0_15_IPC_FORMAT", "0") == "1":
        raise RuntimeError("Arrow legacy IPC format is not supported in PySpark, "
                           "please unset ARROW_PRE_0_15_IPC_FORMAT")
@@ -15,6 +15,8 @@
 # limitations under the License.
 #

+# To disallow implicit relative import. Remove this once we drop Python 2.
+from __future__ import absolute_import
 from __future__ import print_function
 import sys
 import warnings
@@ -25,15 +27,16 @@ if sys.version >= '3':
     basestring = unicode = str
     xrange = range
 else:
-    from itertools import izip as zip, imap as map
+    from itertools import imap as map

 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.pandas.conversion import SparkConversionMixin
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \
+from pyspark.sql.types import Row, DataType, StringType, StructType, \
     _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
     _parse_datatype_string
 from pyspark.sql.utils import install_exception_handler
@@ -60,7 +63,7 @@ def _monkey_patch_RDD(sparkSession):
     RDD.toDF = toDF


-class SparkSession(object):
+class SparkSession(SparkConversionMixin):
     """The entry point to programming Spark with the Dataset and DataFrame API.

     A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as
@ -458,135 +461,6 @@ class SparkSession(object):
|
|||
data = [schema.toInternal(row) for row in data]
|
||||
return self._sc.parallelize(data), schema
|
||||
|
||||
def _get_numpy_record_dtype(self, rec):
|
||||
"""
|
||||
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
|
||||
the dtypes of fields in a record so they can be properly loaded into Spark.
|
||||
:param rec: a numpy record to check field dtypes
|
||||
:return corrected dtype for a numpy.record or None if no correction needed
|
||||
"""
|
||||
import numpy as np
|
||||
cur_dtypes = rec.dtype
|
||||
col_names = cur_dtypes.names
|
||||
record_type_list = []
|
||||
has_rec_fix = False
|
||||
for i in xrange(len(cur_dtypes)):
|
||||
curr_type = cur_dtypes[i]
|
||||
# If type is a datetime64 timestamp, convert to microseconds
|
||||
# NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
|
||||
# conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
|
||||
if curr_type == np.dtype('datetime64[ns]'):
|
||||
curr_type = 'datetime64[us]'
|
||||
has_rec_fix = True
|
||||
record_type_list.append((str(col_names[i]), curr_type))
|
||||
return np.dtype(record_type_list) if has_rec_fix else None
|
||||
|
||||
def _convert_from_pandas(self, pdf, schema, timezone):
|
||||
"""
|
||||
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
|
||||
:return list of records
|
||||
"""
|
||||
if timezone is not None:
|
||||
from pyspark.sql.types import _check_series_convert_timestamps_tz_local
|
||||
copied = False
|
||||
if isinstance(schema, StructType):
|
||||
for field in schema:
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if isinstance(field.dataType, TimestampType):
|
||||
s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
|
||||
if s is not pdf[field.name]:
|
||||
if not copied:
|
||||
# Copy once if the series is modified to prevent the original
|
||||
# Pandas DataFrame from being updated
|
||||
pdf = pdf.copy()
|
||||
copied = True
|
||||
pdf[field.name] = s
|
||||
else:
|
||||
for column, series in pdf.iteritems():
|
||||
s = _check_series_convert_timestamps_tz_local(series, timezone)
|
||||
if s is not series:
|
||||
if not copied:
|
||||
# Copy once if the series is modified to prevent the original
|
||||
# Pandas DataFrame from being updated
|
||||
pdf = pdf.copy()
|
||||
copied = True
|
||||
pdf[column] = s
|
||||
|
||||
# Convert pandas.DataFrame to list of numpy records
|
||||
np_records = pdf.to_records(index=False)
|
||||
|
||||
# Check if any columns need to be fixed for Spark to infer properly
|
||||
if len(np_records) > 0:
|
||||
record_dtype = self._get_numpy_record_dtype(np_records[0])
|
||||
if record_dtype is not None:
|
||||
return [r.astype(record_dtype).tolist() for r in np_records]
|
||||
|
||||
# Convert list of numpy records to python lists
|
||||
return [r.tolist() for r in np_records]
|
||||
|
||||
def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
|
||||
"""
|
||||
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
|
||||
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
|
||||
data types will be used to coerce the data in Pandas to Arrow conversion.
|
||||
"""
|
||||
from pyspark.serializers import ArrowStreamPandasSerializer
|
||||
from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
|
||||
from pyspark.sql.utils import require_minimum_pandas_version, \
|
||||
require_minimum_pyarrow_version
|
||||
|
||||
require_minimum_pandas_version()
|
||||
require_minimum_pyarrow_version()
|
||||
|
||||
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
|
||||
import pyarrow as pa
|
||||
|
||||
# Create the Spark schema from list of names passed in with Arrow types
|
||||
if isinstance(schema, (list, tuple)):
|
||||
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
|
||||
struct = StructType()
|
||||
for name, field in zip(schema, arrow_schema):
|
||||
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
|
||||
schema = struct
|
||||
|
||||
# Determine arrow types to coerce data when creating batches
|
||||
if isinstance(schema, StructType):
|
||||
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
|
||||
elif isinstance(schema, DataType):
|
||||
raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
|
||||
else:
|
||||
# Any timestamps must be coerced to be compatible with Spark
|
||||
arrow_types = [to_arrow_type(TimestampType())
|
||||
if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
|
||||
for t in pdf.dtypes]
|
||||
|
||||
# Slice the DataFrame to be batched
|
||||
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
|
||||
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
|
||||
|
||||
# Create list of Arrow (columns, type) for serializer dump_stream
|
||||
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
|
||||
for pdf_slice in pdf_slices]
|
||||
|
||||
jsqlContext = self._wrapped._jsqlContext
|
||||
|
||||
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
|
||||
col_by_name = True # col by name only applies to StructType columns, can't happen here
|
||||
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
|
||||
|
||||
def reader_func(temp_filename):
|
||||
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
|
||||
|
||||
def create_RDD_server():
|
||||
return self._jvm.ArrowRDDServer(jsqlContext)
|
||||
|
||||
# Create Spark DataFrame from Arrow stream file, using one batch per partition
|
||||
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
|
||||
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
|
||||
df = DataFrame(jdf, self._wrapped)
|
||||
df._schema = schema
|
||||
return df
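Aside (not part of the diff): the round-up slicing trick used above (`-(-len(pdf) // parallelism)` is integer ceiling division), with a hypothetical parallelism of 4.

```python
import pandas as pd

pdf = pd.DataFrame({"x": range(10)})
default_parallelism = 4                       # stand-in for sparkContext.defaultParallelism
step = -(-len(pdf) // default_parallelism)    # ceil(10 / 4) == 3 rows per slice
pdf_slices = [pdf[start:start + step] for start in range(0, len(pdf), step)]
print([len(s) for s in pdf_slices])           # [3, 3, 3, 1]
```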
|
||||
|
||||
@staticmethod
|
||||
def _create_shell_session():
|
||||
"""
|
||||
|
@ -722,46 +596,12 @@ class SparkSession(object):
|
|||
except Exception:
|
||||
has_pandas = False
|
||||
if has_pandas and isinstance(data, pandas.DataFrame):
|
||||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
if self._wrapped._conf.pandasRespectSessionTimeZone():
|
||||
timezone = self._wrapped._conf.sessionLocalTimeZone()
|
||||
else:
|
||||
timezone = None
|
||||
|
||||
# If no schema supplied by user then get the names of columns only
|
||||
if schema is None:
|
||||
schema = [str(x) if not isinstance(x, basestring) else
|
||||
(x.encode('utf-8') if not isinstance(x, str) else x)
|
||||
for x in data.columns]
|
||||
|
||||
if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
|
||||
try:
|
||||
return self._create_from_pandas_with_arrow(data, schema, timezone)
|
||||
except Exception as e:
|
||||
from pyspark.util import _exception_message
|
||||
|
||||
if self._wrapped._conf.arrowPySparkFallbackEnabled():
|
||||
msg = (
|
||||
"createDataFrame attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
|
||||
"failed by the reason below:\n %s\n"
|
||||
"Attempting non-optimization as "
|
||||
"'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
|
||||
"true." % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
else:
|
||||
msg = (
|
||||
"createDataFrame attempted Arrow optimization because "
|
||||
"'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
|
||||
"reached the error below and will not continue because automatic "
|
||||
"fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
|
||||
"has been set to false.\n %s" % _exception_message(e))
|
||||
warnings.warn(msg)
|
||||
raise
|
||||
data = self._convert_from_pandas(data, schema, timezone)
|
||||
# Create a DataFrame from pandas DataFrame.
|
||||
return super(SparkSession, self).createDataFrame(
|
||||
data, schema, verifySchema, samplingRatio)
|
||||
return self._create_dataframe(data, schema, verifySchema, samplingRatio)
|
||||
|
||||
def _create_dataframe(self, data, schema, verifySchema, samplingRatio):
|
||||
if isinstance(schema, StructType):
|
||||
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
|
||||
|
||||
|
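Aside (not part of the diff): the two configs named in the warning messages above, set on an assumed existing `spark` session.

```python
import pandas as pd

spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")

# With both set, createDataFrame(pandas_df) tries the Arrow path and falls back with a warning.
df = spark.createDataFrame(pd.DataFrame({"id": [1, 2, 3]}))
```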
@@ -178,7 +178,7 @@ class ArrowTests(ReusedSQLTestCase):

         self.assertFalse(pdf_ny.equals(pdf_la))

-        from pyspark.sql.types import _check_series_convert_timestamps_local_tz
+        from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz
         pdf_la_corrected = pdf_la.copy()
         for field in self.schema:
             if isinstance(field.dataType, TimestampType):
@@ -311,7 +311,7 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(pdf.equals(pdf_copy))

     def test_schema_conversion_roundtrip(self):
-        from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+        from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
         arrow_schema = to_arrow_schema(self.schema)
         schema_rt = from_arrow_schema(arrow_schema)
         self.assertEquals(self.schema, schema_rt)
@ -1608,267 +1608,6 @@ register_input_converter(DatetimeConverter())
|
|||
register_input_converter(DateConverter())
|
||||
|
||||
|
||||
def to_arrow_type(dt):
|
||||
""" Convert Spark data type to pyarrow type
|
||||
"""
|
||||
import pyarrow as pa
|
||||
if type(dt) == BooleanType:
|
||||
arrow_type = pa.bool_()
|
||||
elif type(dt) == ByteType:
|
||||
arrow_type = pa.int8()
|
||||
elif type(dt) == ShortType:
|
||||
arrow_type = pa.int16()
|
||||
elif type(dt) == IntegerType:
|
||||
arrow_type = pa.int32()
|
||||
elif type(dt) == LongType:
|
||||
arrow_type = pa.int64()
|
||||
elif type(dt) == FloatType:
|
||||
arrow_type = pa.float32()
|
||||
elif type(dt) == DoubleType:
|
||||
arrow_type = pa.float64()
|
||||
elif type(dt) == DecimalType:
|
||||
arrow_type = pa.decimal128(dt.precision, dt.scale)
|
||||
elif type(dt) == StringType:
|
||||
arrow_type = pa.string()
|
||||
elif type(dt) == BinaryType:
|
||||
arrow_type = pa.binary()
|
||||
elif type(dt) == DateType:
|
||||
arrow_type = pa.date32()
|
||||
elif type(dt) == TimestampType:
|
||||
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
|
||||
arrow_type = pa.timestamp('us', tz='UTC')
|
||||
elif type(dt) == ArrayType:
|
||||
if type(dt.elementType) in [StructType, TimestampType]:
|
||||
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
|
||||
arrow_type = pa.list_(to_arrow_type(dt.elementType))
|
||||
elif type(dt) == StructType:
|
||||
if any(type(field.dataType) == StructType for field in dt):
|
||||
raise TypeError("Nested StructType not supported in conversion to Arrow")
|
||||
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
|
||||
for field in dt]
|
||||
arrow_type = pa.struct(fields)
|
||||
else:
|
||||
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
|
||||
return arrow_type
|
||||
|
||||
|
||||
def to_arrow_schema(schema):
|
||||
""" Convert a schema from Spark to Arrow
|
||||
"""
|
||||
import pyarrow as pa
|
||||
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
|
||||
for field in schema]
|
||||
return pa.schema(fields)
|
||||
|
||||
|
||||
def from_arrow_type(at):
|
||||
""" Convert pyarrow type to Spark data type.
|
||||
"""
|
||||
import pyarrow.types as types
|
||||
if types.is_boolean(at):
|
||||
spark_type = BooleanType()
|
||||
elif types.is_int8(at):
|
||||
spark_type = ByteType()
|
||||
elif types.is_int16(at):
|
||||
spark_type = ShortType()
|
||||
elif types.is_int32(at):
|
||||
spark_type = IntegerType()
|
||||
elif types.is_int64(at):
|
||||
spark_type = LongType()
|
||||
elif types.is_float32(at):
|
||||
spark_type = FloatType()
|
||||
elif types.is_float64(at):
|
||||
spark_type = DoubleType()
|
||||
elif types.is_decimal(at):
|
||||
spark_type = DecimalType(precision=at.precision, scale=at.scale)
|
||||
elif types.is_string(at):
|
||||
spark_type = StringType()
|
||||
elif types.is_binary(at):
|
||||
spark_type = BinaryType()
|
||||
elif types.is_date32(at):
|
||||
spark_type = DateType()
|
||||
elif types.is_timestamp(at):
|
||||
spark_type = TimestampType()
|
||||
elif types.is_list(at):
|
||||
if types.is_timestamp(at.value_type):
|
||||
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
|
||||
spark_type = ArrayType(from_arrow_type(at.value_type))
|
||||
elif types.is_struct(at):
|
||||
if any(types.is_struct(field.type) for field in at):
|
||||
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
|
||||
return StructType(
|
||||
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
|
||||
for field in at])
|
||||
else:
|
||||
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
|
||||
return spark_type
|
||||
|
||||
|
||||
def from_arrow_schema(arrow_schema):
|
||||
""" Convert schema from Arrow to Spark.
|
||||
"""
|
||||
return StructType(
|
||||
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
|
||||
for field in arrow_schema])
|
||||
|
||||
|
||||
def _get_local_timezone():
|
||||
""" Get local timezone using pytz with environment variable, or dateutil.
|
||||
|
||||
If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone
|
||||
string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and
|
||||
it reads system configuration to know the system local timezone.
|
||||
|
||||
See also:
|
||||
- https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753
|
||||
- https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338
|
||||
"""
|
||||
import os
|
||||
return os.environ.get('TZ', 'dateutil/:')
|
||||
|
||||
|
||||
def _check_series_localize_timestamps(s, timezone):
|
||||
"""
|
||||
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
|
||||
|
||||
If the input series is not a timestamp series, then the same series is returned. If the input
|
||||
series is a timestamp series, then a converted series is returned.
|
||||
|
||||
:param s: pandas.Series
|
||||
:param timezone: the timezone to convert. if None then use local timezone
|
||||
:return pandas.Series that have been converted to tz-naive
|
||||
"""
|
||||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
from pandas.api.types import is_datetime64tz_dtype
|
||||
tz = timezone or _get_local_timezone()
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64tz_dtype(s.dtype):
|
||||
return s.dt.tz_convert(tz).dt.tz_localize(None)
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def _check_dataframe_localize_timestamps(pdf, timezone):
|
||||
"""
|
||||
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param pdf: pandas.DataFrame
|
||||
:param timezone: the timezone to convert. if None then use local timezone
|
||||
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
|
||||
"""
|
||||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
for column, series in pdf.iteritems():
|
||||
pdf[column] = _check_series_localize_timestamps(series, timezone)
|
||||
return pdf
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_internal(s, timezone):
|
||||
"""
|
||||
Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
|
||||
Spark internal storage
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param timezone: the timezone to convert. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
|
||||
"""
|
||||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64_dtype(s.dtype):
|
||||
# When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive
|
||||
# timestamp is during the hour when the clock is adjusted backward during due to
|
||||
# daylight saving time (dst).
|
||||
# E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to
|
||||
# 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize
|
||||
# a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either
|
||||
# dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500).
|
||||
#
|
||||
# Here we explicit choose to use standard time. This matches the default behavior of
|
||||
# pytz.
|
||||
#
|
||||
# Here are some code to help understand this behavior:
|
||||
# >>> import datetime
|
||||
# >>> import pandas as pd
|
||||
# >>> import pytz
|
||||
# >>>
|
||||
# >>> t = datetime.datetime(2015, 11, 1, 1, 30)
|
||||
# >>> ts = pd.Series([t])
|
||||
# >>> tz = pytz.timezone('America/New_York')
|
||||
# >>>
|
||||
# >>> ts.dt.tz_localize(tz, ambiguous=True)
|
||||
# 0 2015-11-01 01:30:00-04:00
|
||||
# dtype: datetime64[ns, America/New_York]
|
||||
# >>>
|
||||
# >>> ts.dt.tz_localize(tz, ambiguous=False)
|
||||
# 0 2015-11-01 01:30:00-05:00
|
||||
# dtype: datetime64[ns, America/New_York]
|
||||
# >>>
|
||||
# >>> str(tz.localize(t))
|
||||
# '2015-11-01 01:30:00-05:00'
|
||||
tz = timezone or _get_local_timezone()
|
||||
return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
|
||||
elif is_datetime64tz_dtype(s.dtype):
|
||||
return s.dt.tz_convert('UTC')
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
|
||||
"""
|
||||
Convert timestamp to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param from_timezone: the timezone to convert from. if None then use local timezone
|
||||
:param to_timezone: the timezone to convert to. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
|
||||
"""
|
||||
from pyspark.sql.utils import require_minimum_pandas_version
|
||||
require_minimum_pandas_version()
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
|
||||
from_tz = from_timezone or _get_local_timezone()
|
||||
to_tz = to_timezone or _get_local_timezone()
|
||||
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
|
||||
if is_datetime64tz_dtype(s.dtype):
|
||||
return s.dt.tz_convert(to_tz).dt.tz_localize(None)
|
||||
elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
|
||||
# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
|
||||
return s.apply(
|
||||
lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
|
||||
if ts is not pd.NaT else pd.NaT)
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_local_tz(s, timezone):
|
||||
"""
|
||||
Convert timestamp to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param timezone: the timezone to convert to. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
|
||||
"""
|
||||
return _check_series_convert_timestamps_localize(s, None, timezone)
|
||||
|
||||
|
||||
def _check_series_convert_timestamps_tz_local(s, timezone):
|
||||
"""
|
||||
Convert timestamp to timezone-naive in the specified timezone or local timezone
|
||||
|
||||
:param s: a pandas.Series
|
||||
:param timezone: the timezone to convert from. if None then use local timezone
|
||||
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
|
||||
"""
|
||||
return _check_series_convert_timestamps_localize(s, timezone, None)
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
from pyspark.context import SparkContext
|
||||
|
@@ -23,8 +23,8 @@ import sys
 from pyspark import SparkContext, since
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_java_column, _to_seq
-from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
-    to_arrow_type, to_arrow_schema
+from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
+from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.util import _get_argspec

 __all__ = ["UDFRegistration"]
@@ -46,7 +46,7 @@ def _create_udf(f, returnType, evalType):
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                     PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):

-        from pyspark.sql.utils import require_minimum_pyarrow_version
+        from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
         require_minimum_pyarrow_version()

         argspec = _get_argspec(f)
@ -136,50 +136,6 @@ def toJArray(gateway, jtype, arr):
|
|||
return jarr
|
||||
|
||||
|
||||
def require_minimum_pandas_version():
|
||||
""" Raise ImportError if minimum version of Pandas is not installed
|
||||
"""
|
||||
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
|
||||
minimum_pandas_version = "0.23.2"
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
try:
|
||||
import pandas
|
||||
have_pandas = True
|
||||
except ImportError:
|
||||
have_pandas = False
|
||||
if not have_pandas:
|
||||
raise ImportError("Pandas >= %s must be installed; however, "
|
||||
"it was not found." % minimum_pandas_version)
|
||||
if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
|
||||
raise ImportError("Pandas >= %s must be installed; however, "
|
||||
"your version was %s." % (minimum_pandas_version, pandas.__version__))
|
||||
|
||||
|
||||
def require_minimum_pyarrow_version():
|
||||
""" Raise ImportError if minimum version of pyarrow is not installed
|
||||
"""
|
||||
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
|
||||
minimum_pyarrow_version = "0.15.1"
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
import os
|
||||
try:
|
||||
import pyarrow
|
||||
have_arrow = True
|
||||
except ImportError:
|
||||
have_arrow = False
|
||||
if not have_arrow:
|
||||
raise ImportError("PyArrow >= %s must be installed; however, "
|
||||
"it was not found." % minimum_pyarrow_version)
|
||||
if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
|
||||
raise ImportError("PyArrow >= %s must be installed; however, "
|
||||
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
|
||||
if os.environ.get("ARROW_PRE_0_15_IPC_FORMAT", "0") == "1":
|
||||
raise RuntimeError("Arrow legacy IPC format is not supported in PySpark, "
|
||||
"please unset ARROW_PRE_0_15_IPC_FORMAT")
|
||||
|
||||
|
||||
def require_test_compiled():
|
||||
""" Raise Exception if test classes are not compiled
|
||||
"""
|
||||
|
@@ -29,7 +29,7 @@ from pyspark.util import _exception_message

 pandas_requirement_message = None
 try:
-    from pyspark.sql.utils import require_minimum_pandas_version
+    from pyspark.sql.pandas.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 except ImportError as e:
     # If Pandas version requirement is not satisfied, skip related tests.
@@ -37,7 +37,7 @@ except ImportError as e:

 pyarrow_requirement_message = None
 try:
-    from pyspark.sql.utils import require_minimum_pyarrow_version
+    from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
     require_minimum_pyarrow_version()
 except ImportError as e:
     # If Arrow version requirement is not satisfied, skip related tests.
@@ -39,8 +39,10 @@ from pyspark.resourceinformation import ResourceInformation
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
-    BatchedSerializer, ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
-from pyspark.sql.types import to_arrow_type, StructType
+    BatchedSerializer
+from pyspark.sql.pandas.serializers import ArrowStreamPandasUDFSerializer, CogroupUDFSerializer
+from pyspark.sql.pandas.types import to_arrow_type
+from pyspark.sql.types import StructType
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle

@@ -179,6 +179,8 @@ try:
             'pyspark.ml.linalg',
             'pyspark.ml.param',
             'pyspark.sql',
+            'pyspark.sql.avro',
+            'pyspark.sql.pandas',
             'pyspark.streaming',
             'pyspark.bin',
             'pyspark.sbin',
@@ -105,7 +105,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
       Seq(
         pythonExec,
         "-c",
-        "from pyspark.sql.utils import require_minimum_pandas_version;" +
+        "from pyspark.sql.pandas.utils import require_minimum_pandas_version;" +
           "require_minimum_pandas_version()"),
       None,
       "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
@@ -117,7 +117,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
       Seq(
         pythonExec,
         "-c",
-        "from pyspark.sql.utils import require_minimum_pyarrow_version;" +
+        "from pyspark.sql.pandas.utils import require_minimum_pyarrow_version;" +
           "require_minimum_pyarrow_version()"),
       None,
       "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!