diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 41793593e9..9a9da1f3f8 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -362,6 +362,13 @@ pyspark_sql = Module( "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.avro.functions", + "pyspark.sql.pandas.conversion", + "pyspark.sql.pandas.map_ops", + "pyspark.sql.pandas.functions", + "pyspark.sql.pandas.group_ops", + "pyspark.sql.pandas.types", + "pyspark.sql.pandas.serializers", + "pyspark.sql.pandas.utils", # unittests "pyspark.sql.tests.test_arrow", "pyspark.sql.tests.test_catalog", diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index d5a3173ff9..1c983172d3 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -24,7 +24,7 @@ Run with: from __future__ import print_function from pyspark.sql import SparkSession -from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version +from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version require_minimum_pandas_version() require_minimum_pyarrow_version() diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 5da7b44a95..b69562e845 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -7,6 +7,7 @@ Module Context .. automodule:: pyspark.sql :members: :undoc-members: + :inherited-members: :exclude-members: builder .. We need `exclude-members` to prevent default description generations as a workaround for old Sphinx (< 1.6.6). diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 994e42e238..49b7cb4546 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,248 +185,6 @@ class FramedSerializer(Serializer): raise NotImplementedError -class ArrowCollectSerializer(Serializer): - """ - Deserialize a stream of batches followed by batch order information. Used in - DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM. - """ - - def __init__(self): - self.serializer = ArrowStreamSerializer() - - def dump_stream(self, iterator, stream): - return self.serializer.dump_stream(iterator, stream) - - def load_stream(self, stream): - """ - Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields - a list of indices that can be used to put the RecordBatches in the correct order. - """ - # load the batches - for batch in self.serializer.load_stream(stream): - yield batch - - # load the batch order indices or propagate any error that occurred in the JVM - num = read_int(stream) - if num == -1: - error_msg = UTF8Deserializer().loads(stream) - raise RuntimeError("An error occurred while calling " - "ArrowCollectSerializer.load_stream: {}".format(error_msg)) - batch_order = [] - for i in xrange(num): - index = read_int(stream) - batch_order.append(index) - yield batch_order - - def __repr__(self): - return "ArrowCollectSerializer(%s)" % self.serializer - - -class ArrowStreamSerializer(Serializer): - """ - Serializes Arrow record batches as a stream. 
- """ - - def dump_stream(self, iterator, stream): - import pyarrow as pa - writer = None - try: - for batch in iterator: - if writer is None: - writer = pa.RecordBatchStreamWriter(stream, batch.schema) - writer.write_batch(batch) - finally: - if writer is not None: - writer.close() - - def load_stream(self, stream): - import pyarrow as pa - reader = pa.ipc.open_stream(stream) - for batch in reader: - yield batch - - def __repr__(self): - return "ArrowStreamSerializer" - - -class ArrowStreamPandasSerializer(ArrowStreamSerializer): - """ - Serializes Pandas.Series as Arrow data with Arrow streaming format. - - :param timezone: A timezone to respect when handling timestamp values - :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation - :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name - """ - - def __init__(self, timezone, safecheck, assign_cols_by_name): - super(ArrowStreamPandasSerializer, self).__init__() - self._timezone = timezone - self._safecheck = safecheck - self._assign_cols_by_name = assign_cols_by_name - - def arrow_to_pandas(self, arrow_column): - from pyspark.sql.types import _check_series_localize_timestamps - - # If the given column is a date type column, creates a series of datetime.date directly - # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by - # datetime64[ns] type handling. - s = arrow_column.to_pandas(date_as_object=True) - - s = _check_series_localize_timestamps(s, self._timezone) - return s - - def _create_batch(self, series): - """ - Create an Arrow record batch from the given pandas.Series or list of Series, - with optional type. - - :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) - :return: Arrow RecordBatch - """ - import pandas as pd - import pyarrow as pa - from pyspark.sql.types import _check_series_convert_timestamps_internal - # Make input conform to [(series1, type1), (series2, type2), ...] - if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - def create_array(s, t): - mask = s.isnull() - # Ensure timestamp series are in expected form for Spark internal representation - if t is not None and pa.types.is_timestamp(t): - s = _check_series_convert_timestamps_internal(s, self._timezone) - try: - array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) - except pa.ArrowException as e: - error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ - "Array (%s). It can be caused by overflows or other unsafe " + \ - "conversions warned by Arrow. Arrow safe type check can be " + \ - "disabled by using SQL config " + \ - "`spark.sql.execution.pandas.arrowSafeTypeConversion`." 
- raise RuntimeError(error_msg % (s.dtype, t), e) - return array - - arrs = [] - for s, t in series: - if t is not None and pa.types.is_struct(t): - if not isinstance(s, pd.DataFrame): - raise ValueError("A field of type StructType expects a pandas.DataFrame, " - "but got: %s" % str(type(s))) - - # Input partition and result pandas.DataFrame empty, make empty Arrays with struct - if len(s) == 0 and len(s.columns) == 0: - arrs_names = [(pa.array([], type=field.type), field.name) for field in t] - # Assign result columns by schema name if user labeled with strings - elif self._assign_cols_by_name and any(isinstance(name, basestring) - for name in s.columns): - arrs_names = [(create_array(s[field.name], field.type), field.name) - for field in t] - # Assign result columns by position - else: - arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) - for i, field in enumerate(t)] - - struct_arrs, struct_names = zip(*arrs_names) - arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) - else: - arrs.append(create_array(s, t)) - - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - - def dump_stream(self, iterator, stream): - """ - Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or - a list of series accompanied by an optional pyarrow type to coerce the data to. - """ - batches = (self._create_batch(series) for series in iterator) - super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) - - def load_stream(self, stream): - """ - Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. - """ - batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) - import pyarrow as pa - for batch in batches: - yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] - - def __repr__(self): - return "ArrowStreamPandasSerializer" - - -class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): - """ - Serializer used by Python worker to evaluate Pandas UDFs - """ - - def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False): - super(ArrowStreamPandasUDFSerializer, self) \ - .__init__(timezone, safecheck, assign_cols_by_name) - self._df_for_struct = df_for_struct - - def arrow_to_pandas(self, arrow_column): - import pyarrow.types as types - - if self._df_for_struct and types.is_struct(arrow_column.type): - import pandas as pd - series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column) - .rename(field.name) - for column, field in zip(arrow_column.flatten(), arrow_column.type)] - s = pd.concat(series, axis=1) - else: - s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column) - return s - - def dump_stream(self, iterator, stream): - """ - Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. 
- """ - - def init_stream_yield_batches(): - should_write_start_length = True - for series in iterator: - batch = self._create_batch(series) - if should_write_start_length: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False - yield batch - - return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) - - def __repr__(self): - return "ArrowStreamPandasUDFSerializer" - - -class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): - - def load_stream(self, stream): - """ - Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two - lists of pandas.Series. - """ - import pyarrow as pa - dataframes_in_group = None - - while dataframes_in_group is None or dataframes_in_group > 0: - dataframes_in_group = read_int(stream) - - if dataframes_in_group == 2: - batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] - batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] - yield ( - [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()], - [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()] - ) - - elif dataframes_in_group != 0: - raise ValueError( - 'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group)) - - class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index ba4c4feec7..0a8d71c12e 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -51,12 +51,12 @@ from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStat from pyspark.sql.group import GroupedData from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from pyspark.sql.window import Window, WindowSpec -from pyspark.sql.cogroup import CoGroupedData +from pyspark.sql.pandas.group_ops import PandasCogroupedOps __all__ = [ 'SparkSession', 'SQLContext', 'UDFRegistration', 'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', - 'DataFrameReader', 'DataFrameWriter', 'CoGroupedData' + 'DataFrameReader', 'DataFrameWriter', 'PandasCogroupedOps' ] diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2fa90d6788..8f4454a08d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -31,8 +31,8 @@ import warnings from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, \ - ignore_unicode_prefix, PythonEvalType -from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \ + ignore_unicode_prefix +from pyspark.serializers import BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -40,14 +40,14 @@ from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.streaming import DataStreamWriter -from pyspark.sql.types import IntegralType from pyspark.sql.types import * -from pyspark.util import _exception_message +from pyspark.sql.pandas.conversion import PandasConversionMixin +from pyspark.sql.pandas.map_ops import PandasMapOpsMixin __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] -class DataFrame(object): +class 
DataFrame(PandasMapOpsMixin, PandasConversionMixin): """A distributed collection of data grouped into named columns. A :class:`DataFrame` is equivalent to a relational table in Spark SQL, @@ -2135,193 +2135,6 @@ class DataFrame(object): "should have been DataFrame." % type(result) return result - @since(1.3) - def toPandas(self): - """ - Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. - - This is only available if Pandas is installed and available. - - .. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is - expected to be small, as all the data is loaded into the driver's memory. - - .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental. - - >>> df.toPandas() # doctest: +SKIP - age name - 0 2 Alice - 1 5 Bob - """ - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - - import numpy as np - import pandas as pd - - if self.sql_ctx._conf.pandasRespectSessionTimeZone(): - timezone = self.sql_ctx._conf.sessionLocalTimeZone() - else: - timezone = None - - if self.sql_ctx._conf.arrowPySparkEnabled(): - use_arrow = True - try: - from pyspark.sql.types import to_arrow_schema - from pyspark.sql.utils import require_minimum_pyarrow_version - - require_minimum_pyarrow_version() - to_arrow_schema(self.schema) - except Exception as e: - - if self.sql_ctx._conf.arrowPySparkFallbackEnabled(): - msg = ( - "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "Attempting non-optimization as " - "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to " - "true." % _exception_message(e)) - warnings.warn(msg) - use_arrow = False - else: - msg = ( - "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has " - "reached the error below and will not continue because automatic fallback " - "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to " - "false.\n %s" % _exception_message(e)) - warnings.warn(msg) - raise - - # Try to use Arrow optimization when the schema is supported and the required version - # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled. - if use_arrow: - try: - from pyspark.sql.types import _check_dataframe_localize_timestamps - import pyarrow - batches = self._collectAsArrow() - if len(batches) > 0: - table = pyarrow.Table.from_batches(batches) - # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type - # values, but we should use datetime.date to match the behavior with when - # Arrow optimization is disabled. - pdf = table.to_pandas(date_as_object=True) - return _check_dataframe_localize_timestamps(pdf, timezone) - else: - return pd.DataFrame.from_records([], columns=self.columns) - except Exception as e: - # We might have to allow fallback here as well but multiple Spark jobs can - # be executed. So, simply fail in this case for now. - msg = ( - "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has " - "reached the error below and can not continue. Note that " - "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an " - "effect on failures in the middle of " - "computation.\n %s" % _exception_message(e)) - warnings.warn(msg) - raise - - # Below is toPandas without Arrow optimization. 
- pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as float column. Once we convert the column with NaN back - # to integer type e.g., np.int16, we will hit exception. So we use the inferred - # float type, not the corrected type from the schema in this case. - if pandas_type is not None and \ - not(isinstance(field.dataType, IntegralType) and field.nullable and - pdf[field.name].isnull().any()): - dtype[field.name] = pandas_type - # Ensure we fall back to nullable numpy types, even when whole column is null: - if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any(): - dtype[field.name] = np.float64 - if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any(): - dtype[field.name] = np.object - - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - - if timezone is None: - return pdf - else: - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - pdf[field.name] = \ - _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) - return pdf - - def mapInPandas(self, udf): - """ - Maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined - function and returns the result as a :class:`DataFrame`. - - The user-defined function should take an iterator of `pandas.DataFrame`\\s and return - another iterator of `pandas.DataFrame`\\s. All columns are passed - together as an iterator of `pandas.DataFrame`\\s to the user-defined function and the - returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`. - Each `pandas.DataFrame` size can be controlled by - `spark.sql.execution.arrow.maxRecordsPerBatch`. - Its schema must match the returnType of the Pandas user-defined function. - - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame([(1, 21), (2, 30)], - ... ("id", "age")) # doctest: +SKIP - >>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP - ... def filter_func(batch_iter): - ... for pdf in batch_iter: - ... yield pdf[pdf.id == 1] - >>> df.mapInPandas(filter_func).show() # doctest: +SKIP - +---+---+ - | id|age| - +---+---+ - | 1| 21| - +---+---+ - - .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` - - """ - # Columns are special because hasattr always return True - if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.evalType != PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: - raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " - "MAP_ITER.") - - udf_column = udf(*[self[col] for col in self.columns]) - jdf = self._jdf.mapInPandas(udf_column._jc.expr()) - return DataFrame(jdf, self.sql_ctx) - - def _collectAsArrow(self): - """ - Returns all records as a list of ArrowRecordBatches, pyarrow must be installed - and available on driver and worker Python environments. - - .. note:: Experimental. 
- """ - with SCCallSiteSync(self._sc) as css: - port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython() - - # Collect list of un-ordered batches where last element is a list of correct order indices - try: - results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer())) - finally: - # Join serving thread and raise any exceptions from collectAsArrowToPython - jsocket_auth_server.getResult() - - # Separate RecordBatches from batch order indices in results - batches = results[:-1] - batch_order = results[-1] - - # Re-order the batch list using the correct order - return [batches[i] for i in batch_order] - ########################################################################################## # Pandas compatibility ########################################################################################## @@ -2349,33 +2162,6 @@ def _to_scala_map(sc, jm): return sc._jvm.PythonUtils.toScalaMap(jm) -def _to_corrected_pandas_type(dt): - """ - When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type may be - wrong. This method gets the corrected data type for Pandas if that type may be inferred - uncorrectly. - """ - import numpy as np - if type(dt) == ByteType: - return np.int8 - elif type(dt) == ShortType: - return np.int16 - elif type(dt) == IntegerType: - return np.int32 - elif type(dt) == LongType: - return np.int64 - elif type(dt) == FloatType: - return np.float32 - elif type(dt) == DoubleType: - return np.float64 - elif type(dt) == BooleanType: - return np.bool - elif type(dt) == TimestampType: - return np.datetime64 - else: - return None - - class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a8d4732237..176729eb51 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -36,6 +36,8 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf +# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264 +from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType from pyspark.sql.utils import to_str # Note to developers: all of PySpark functions here take string as column names whenever possible. @@ -2814,22 +2816,6 @@ def from_csv(col, schema, options={}): # ---------------------------- User Defined Function ---------------------------------- -class PandasUDFType(object): - """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. - """ - SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF - - SCALAR_ITER = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF - - GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF - - COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF - - GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF - - MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF - - @since(1.3) def udf(f=None, returnType=StringType()): """Creates a user defined function (UDF). @@ -2917,483 +2903,6 @@ def udf(f=None, returnType=StringType()): evalType=PythonEvalType.SQL_BATCHED_UDF) -@since(2.3) -def pandas_udf(f=None, returnType=None, functionType=None): - """ - Creates a vectorized user defined function (UDF). - - :param f: user-defined function. 
A python function if used as a standalone function - :param returnType: the return type of the user-defined function. The value can be either a - :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. - Default: SCALAR. - - The function type of the UDF can be one of the following: - - 1. SCALAR - - A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. - If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`. - - :class:`MapType`, nested :class:`StructType` are currently not supported as output types. - - Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP - >>> @pandas_udf(StringType()) # doctest: +SKIP - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], - ... ("id", "name", "age")) # doctest: +SKIP - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ - >>> @pandas_udf("first string, last string") # doctest: +SKIP - ... def split_expand(n): - ... return n.str.split(expand=True) - >>> df.select(split_expand("name")).show() # doctest: +SKIP - +------------------+ - |split_expand(name)| - +------------------+ - | [John, Doe]| - +------------------+ - - .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input - column, but is the length of an internal batch used for each call to the function. - Therefore, this can be used, for example, to ensure the length of each returned - `pandas.Series`, and can not be used as the column length. - - 2. SCALAR_ITER - - A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the - wrapped Python function takes an iterator of batches as input instead of a single batch and, - instead of returning a single output batch, it yields output batches or explicitly returns an - generator or an iterator of output batches. - It is useful when the UDF execution requires initializing some state, e.g., loading a machine - learning model file to apply inference to every input batch. - - .. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all - batches from one partition, although it is currently implemented this way. - Your code shall not rely on this behavior because it might change in the future for - further optimization, e.g., one invocation processes multiple partitions. - - Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. 
- - >>> import pandas as pd # doctest: +SKIP - >>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType - >>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP - >>> df = spark.createDataFrame(pdf) # doctest: +SKIP - - When the UDF is called with a single column that is not `StructType`, the input to the - underlying function is an iterator of `pd.Series`. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def plus_one(batch_iter): - ... for x in batch_iter: - ... yield x + 1 - ... - >>> df.select(plus_one(col("x"))).show() # doctest: +SKIP - +-----------+ - |plus_one(x)| - +-----------+ - | 2| - | 3| - | 4| - +-----------+ - - When the UDF is called with more than one columns, the input to the underlying function is an - iterator of `pd.Series` tuple. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def multiply_two_cols(batch_iter): - ... for a, b in batch_iter: - ... yield a * b - ... - >>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP - +-----------------------+ - |multiply_two_cols(x, x)| - +-----------------------+ - | 1| - | 4| - | 9| - +-----------------------+ - - When the UDF is called with a single column that is `StructType`, the input to the underlying - function is an iterator of `pd.DataFrame`. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def multiply_two_nested_cols(pdf_iter): - ... for pdf in pdf_iter: - ... yield pdf["a"] * pdf["b"] - ... - >>> df.select( - ... multiply_two_nested_cols( - ... struct(col("x").alias("a"), col("x").alias("b")) - ... ).alias("y") - ... ).show() # doctest: +SKIP - +---+ - | y| - +---+ - | 1| - | 4| - | 9| - +---+ - - In the UDF, you can initialize some states before processing batches, wrap your code with - `try ... finally ...` or use context managers to ensure the release of resources at the end - or in case of early termination. - - >>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def plus_y(batch_iter): - ... y = y_bc.value # initialize some state - ... try: - ... for x in batch_iter: - ... yield x + y - ... finally: - ... pass # release resources here, if any - ... - >>> df.select(plus_y(col("x"))).show() # doctest: +SKIP - +---------+ - |plus_y(x)| - +---------+ - | 2| - | 3| - | 4| - +---------+ - - 3. GROUPED_MAP - - A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` - The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match - the field names in the defined returnType schema if specified as strings, or match the - field data types by position if not strings, e.g. integer indices. - The length of the returned `pandas.DataFrame` can be arbitrary. - - Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP - ... def normalize(pdf): - ... v = pdf.v - ... 
return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP - +---+-------------------+ - | id| v| - +---+-------------------+ - | 1|-0.7071067811865475| - | 1| 0.7071067811865475| - | 2|-0.8320502943378437| - | 2|-0.2773500981126146| - | 2| 1.1094003924504583| - +---+-------------------+ - - Alternatively, the user can define a function that takes two arguments. - In this case, the grouping key(s) will be passed as the first argument and the data will - be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy - data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in - as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. - This is useful when the user does not want to hardcode grouping key(s) in the function. - - >>> import pandas as pd # doctest: +SKIP - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP - ... def mean_udf(key, pdf): - ... # key is a tuple of one numpy.int64, which is the value - ... # of 'id' for the current group - ... return pd.DataFrame([key + (pdf.v.mean(),)]) - >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP - +---+---+ - | id| v| - +---+---+ - | 1|1.5| - | 2|6.0| - +---+---+ - >>> @pandas_udf( - ... "id long, `ceil(v / 2)` long, v double", - ... PandasUDFType.GROUPED_MAP) # doctest: +SKIP - >>> def sum_udf(key, pdf): - ... # key is a tuple of two numpy.int64s, which is the values - ... # of 'id' and 'ceil(df.v / 2)' for the current group - ... return pd.DataFrame([key + (pdf.v.sum(),)]) - >>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP - +---+-----------+----+ - | id|ceil(v / 2)| v| - +---+-----------+----+ - | 2| 5|10.0| - | 1| 1| 3.0| - | 2| 3| 5.0| - | 2| 2| 3.0| - +---+-----------+----+ - - .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is - recommended to explicitly index the columns by name to ensure the positions are correct, - or alternatively use an `OrderedDict`. - For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or - `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. - - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - - 4. GROUPED_AGG - - A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar - The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. - The returned scalar can be either a python primitive type, e.g., `int` or `float` - or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - - :class:`MapType` and :class:`StructType` are currently not supported as output types. - - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and - :class:`pyspark.sql.Window` - - This example shows using grouped aggregated UDFs with groupby: - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def mean_udf(v): - ... 
return v.mean() - >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP - +---+-----------+ - | id|mean_udf(v)| - +---+-----------+ - | 1| 1.5| - | 2| 6.0| - +---+-----------+ - - This example shows using grouped aggregated UDFs as window functions. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> from pyspark.sql import Window - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def mean_udf(v): - ... return v.mean() - >>> w = (Window.partitionBy('id') - ... .orderBy('v') - ... .rowsBetween(-1, 0)) - >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP - +---+----+------+ - | id| v|mean_v| - +---+----+------+ - | 1| 1.0| 1.0| - | 1| 2.0| 1.5| - | 2| 3.0| 3.0| - | 2| 5.0| 4.0| - | 2|10.0| 7.5| - +---+----+------+ - - .. note:: For performance reasons, the input series to window functions are not copied. - Therefore, mutating the input series is not allowed and will cause incorrect results. - For the same reason, users should also not rely on the index of the input series. - - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` - - 5. MAP_ITER - - A map iterator Pandas UDFs are used to transform data with an iterator of batches. - It can be used with :meth:`pyspark.sql.DataFrame.mapInPandas`. - - It can return the output of arbitrary length in contrast to the scalar Pandas UDF. - It maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined - function and returns the result as a :class:`DataFrame`. - - The user-defined function should take an iterator of `pandas.DataFrame`\\s and return another - iterator of `pandas.DataFrame`\\s. All columns are passed together as an - iterator of `pandas.DataFrame`\\s to the user-defined function and the returned iterator of - `pandas.DataFrame`\\s are combined as a :class:`DataFrame`. - - >>> df = spark.createDataFrame([(1, 21), (2, 30)], - ... ("id", "age")) # doctest: +SKIP - >>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP - ... def filter_func(batch_iter): - ... for pdf in batch_iter: - ... yield pdf[pdf.id == 1] - >>> df.mapInPandas(filter_func).show() # doctest: +SKIP - +---+---+ - | id|age| - +---+---+ - | 1| 21| - +---+---+ - - 6. COGROUPED_MAP - - A cogrouped map UDF defines transformation: (`pandas.DataFrame`, `pandas.DataFrame`) -> - `pandas.DataFrame`. The `returnType` should be a :class:`StructType` describing the schema - of the returned `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` - must either match the field names in the defined `returnType` schema if specified as strings, - or match the field data types by position if not strings, e.g. integer indices. The length - of the returned `pandas.DataFrame` can be arbitrary. - - CoGrouped map UDFs are used with :meth:`pyspark.sql.CoGroupedData.apply`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df1 = spark.createDataFrame( - ... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], - ... ("time", "id", "v1")) - >>> df2 = spark.createDataFrame( - ... [(20000101, 1, "x"), (20000101, 2, "y")], - ... ("time", "id", "v2")) - >>> @pandas_udf("time int, id int, v1 double, v2 string", - ... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP - ... def asof_join(l, r): - ... 
return pd.merge_asof(l, r, on="time", by="id") - >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP - +---------+---+---+---+ - | time| id| v1| v2| - +---------+---+---+---+ - | 20000101| 1|1.0| x| - | 20000102| 1|3.0| x| - | 20000101| 2|2.0| y| - | 20000102| 2|4.0| y| - +---------+---+---+---+ - - Alternatively, the user can define a function that takes three arguments. In this case, - the grouping key(s) will be passed as the first argument and the data will be passed as the - second and third arguments. The grouping key(s) will be passed as a tuple of numpy data - types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two - `pandas.DataFrame` containing all columns from the original Spark DataFrames. - >>> @pandas_udf("time int, id int, v1 double, v2 string", - ... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP - ... def asof_join(k, l, r): - ... if k == (1,): - ... return pd.merge_asof(l, r, on="time", by="id") - ... else: - ... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2']) - >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP - +---------+---+---+---+ - | time| id| v1| v2| - +---------+---+---+---+ - | 20000101| 1|1.0| x| - | 20000102| 1|3.0| x| - +---------+---+---+---+ - - .. note:: The user-defined functions are considered deterministic by default. Due to - optimization, duplicate invocations may be eliminated or the function may even be invoked - more times than it is present in the query. If your function is not deterministic, call - `asNondeterministic` on the user defined function. E.g.: - - >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP - ... def random(v): - ... import numpy as np - ... import pandas as pd - ... return pd.Series(np.random.randn(len(v)) - >>> random = random.asNondeterministic() # doctest: +SKIP - - .. note:: The user-defined functions do not support conditional expressions or short circuiting - in boolean expressions and it ends up with being executed all internally. If the functions - can fail on special rows, the workaround is to incorporate the condition into the functions. - - .. note:: The user-defined functions do not take keyword arguments on the calling side. - - .. note:: The data type of returned `pandas.Series` from the user-defined functions should be - matched with defined returnType (see :meth:`types.to_arrow_type` and - :meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do - conversion on returned data. The conversion is not guaranteed to be correct and results - should be checked for accuracy by users. - """ - - # The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that - # are not yet visible to the user. Some of behaviors are buggy and might be changed in the near - # future. The table might have to be eventually documented externally. - # Please see SPARK-28132's PR to see the codes in order to generate the table below. 
- # - # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa - # |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa - # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa - # | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| X| X| # noqa - # | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| 0| X| # noqa - # | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa - # | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa - # | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| X| X| # noqa - # | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa - # | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa - # | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| X| X| # noqa - # | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| X| X| # noqa - # | string| None| ''| ''| ''| '\x01'| '\x01'| ''| ''| '\x01'| '\x01'| ''| ''| ''| X| X| 'a'| X| X| ''| X| ''| X| X| # noqa - # | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa - # | array| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa - # | map| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa - # | struct<_1:int>| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa - # | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')| X| bytearray(b'')| # noqa - # 
+-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa - # - # Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be - # used in `returnType`. - # Note: The values inside of the table are generated by `repr`. - # Note: Python 3.7.3, Pandas 0.24.2 and PyArrow 0.13.0 are used. - # Note: Timezone is KST. - # Note: 'X' means it throws an exception during the conversion. - - # decorator @pandas_udf(returnType, functionType) - is_decorator = f is None or isinstance(f, (str, DataType)) - - if is_decorator: - # If DataType has been passed as a positional argument - # for decorator use it as a returnType - return_type = f or returnType - - if functionType is not None: - # @pandas_udf(dataType, functionType=functionType) - # @pandas_udf(returnType=dataType, functionType=functionType) - eval_type = functionType - elif returnType is not None and isinstance(returnType, int): - # @pandas_udf(dataType, functionType) - eval_type = returnType - else: - # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) - eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF - else: - return_type = returnType - - if functionType is not None: - eval_type = functionType - else: - eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF - - if return_type is None: - raise ValueError("Invalid returnType: returnType can not be None") - - if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]: - raise ValueError("Invalid functionType: " - "functionType must be one the values from PandasUDFType") - - if is_decorator: - return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) - else: - return _create_udf(f=f, returnType=return_type, evalType=eval_type) - - blacklist = ['map', 'since', 'ignore_unicode_prefix'] __all__ = [k for k, v in globals().items() if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fcad641424..ac826bc64a 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -18,11 +18,11 @@ import sys from pyspark import since -from pyspark.rdd import ignore_unicode_prefix, PythonEvalType +from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame +from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin from pyspark.sql.types import * -from pyspark.sql.cogroup import CoGroupedData __all__ = ["GroupedData"] @@ -47,7 +47,7 @@ def df_varargs_api(f): return _api -class GroupedData(object): +class GroupedData(PandasGroupedOpsMixin): """ A set of methods for aggregations on a :class:`DataFrame`, created by :func:`DataFrame.groupBy`. 
@@ -219,68 +219,6 @@ class GroupedData(object): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self._df) - @since(3.0) - def cogroup(self, other): - """ - Cogroups this group with another group so that we can run cogrouped operations. - - See :class:`CoGroupedData` for the operations that can be run. - """ - return CoGroupedData(self, other) - - @since(2.3) - def apply(self, udf): - """ - Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result - as a `DataFrame`. - - The user-defined function should take a `pandas.DataFrame` and return another - `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame` are combined as a - :class:`DataFrame`. - - The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the - returnType of the pandas udf. - - .. note:: This function requires a full shuffle. All the data of a group will be loaded - into memory, so the user should be aware of the potential OOM risk if data is skewed - and certain groups are too large to fit in memory. - - :param udf: a grouped map user-defined function returned by - :func:`pyspark.sql.functions.pandas_udf`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP - ... def normalize(pdf): - ... v = pdf.v - ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP - +---+-------------------+ - | id| v| - +---+-------------------+ - | 1|-0.7071067811865475| - | 1| 0.7071067811865475| - | 2|-0.8320502943378437| - | 2|-0.2773500981126146| - | 2| 1.1094003924504583| - +---+-------------------+ - - .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` - - """ - # Columns are special because hasattr always return True - if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " - "GROUPED_MAP.") - df = self._df - udf_column = udf(*[df[col] for col in df.columns]) - jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) - return DataFrame(jdf, self.sql_ctx) - def _test(): import doctest diff --git a/python/pyspark/sql/pandas/__init__.py b/python/pyspark/sql/pandas/__init__.py new file mode 100644 index 0000000000..32a88e9b37 --- /dev/null +++ b/python/pyspark/sql/pandas/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +""" +This package includes the internal APIs for PySpark about interoperability +between pandas, PySpark and PyArrow. This package should not be directly +imported and used. +""" diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py new file mode 100644 index 0000000000..1c957a1665 --- /dev/null +++ b/python/pyspark/sql/pandas/conversion.py @@ -0,0 +1,431 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +import warnings +if sys.version >= '3': + basestring = unicode = str + xrange = range +else: + from itertools import izip as zip + +from pyspark import since +from pyspark.rdd import _load_from_socket +from pyspark.sql.pandas.serializers import ArrowCollectSerializer +from pyspark.sql.types import IntegralType +from pyspark.sql.types import * +from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import _exception_message + + +class PandasConversionMixin(object): + """ + Min-in for the conversion from Spark to pandas. Currently, only :class:`DataFrame` + can use this class. + """ + + @since(1.3) + def toPandas(self): + """ + Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + + This is only available if Pandas is installed and available. + + .. note:: This method should only be used if the resulting Pandas's :class:`DataFrame` is + expected to be small, as all the data is loaded into the driver's memory. + + .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental. + + >>> df.toPandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + from pyspark.sql.dataframe import DataFrame + + assert isinstance(self, DataFrame) + + from pyspark.sql.pandas.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + import numpy as np + import pandas as pd + + if self.sql_ctx._conf.pandasRespectSessionTimeZone(): + timezone = self.sql_ctx._conf.sessionLocalTimeZone() + else: + timezone = None + + if self.sql_ctx._conf.arrowPySparkEnabled(): + use_arrow = True + try: + from pyspark.sql.pandas.types import to_arrow_schema + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version + + require_minimum_pyarrow_version() + to_arrow_schema(self.schema) + except Exception as e: + + if self.sql_ctx._conf.arrowPySparkFallbackEnabled(): + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempting non-optimization as " + "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to " + "true." 
% _exception_message(e)) + warnings.warn(msg) + use_arrow = False + else: + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has " + "reached the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise + + # Try to use Arrow optimization when the schema is supported and the required version + # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled. + if use_arrow: + try: + from pyspark.sql.pandas.types import _check_dataframe_localize_timestamps + import pyarrow + batches = self._collect_as_arrow() + if len(batches) > 0: + table = pyarrow.Table.from_batches(batches) + # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type + # values, but we should use datetime.date to match the behavior with when + # Arrow optimization is disabled. + pdf = table.to_pandas(date_as_object=True) + return _check_dataframe_localize_timestamps(pdf, timezone) + else: + return pd.DataFrame.from_records([], columns=self.columns) + except Exception as e: + # We might have to allow fallback here as well but multiple Spark jobs can + # be executed. So, simply fail in this case for now. + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has " + "reached the error below and can not continue. Note that " + "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an " + "effect on failures in the middle of " + "computation.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise + + # Below is toPandas without Arrow optimization. + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType) + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as float column. Once we convert the column with NaN back + # to integer type e.g., np.int16, we will hit exception. So we use the inferred + # float type, not the corrected type from the schema in this case. + if pandas_type is not None and \ + not(isinstance(field.dataType, IntegralType) and field.nullable and + pdf[field.name].isnull().any()): + dtype[field.name] = pandas_type + # Ensure we fall back to nullable numpy types, even when whole column is null: + if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any(): + dtype[field.name] = np.float64 + if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any(): + dtype[field.name] = np.object + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + + if timezone is None: + return pdf + else: + from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz + for field in self.schema: + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf + + @staticmethod + def _to_corrected_pandas_type(dt): + """ + When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type + may be wrong. This method gets the corrected data type for Pandas if that type may be + inferred incorrectly. 
+ """ + import numpy as np + if type(dt) == ByteType: + return np.int8 + elif type(dt) == ShortType: + return np.int16 + elif type(dt) == IntegerType: + return np.int32 + elif type(dt) == LongType: + return np.int64 + elif type(dt) == FloatType: + return np.float32 + elif type(dt) == DoubleType: + return np.float64 + elif type(dt) == BooleanType: + return np.bool + elif type(dt) == TimestampType: + return np.datetime64 + else: + return None + + def _collect_as_arrow(self): + """ + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. + + .. note:: Experimental. + """ + from pyspark.sql.dataframe import DataFrame + + assert isinstance(self, DataFrame) + + with SCCallSiteSync(self._sc): + port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython() + + # Collect list of un-ordered batches where last element is a list of correct order indices + try: + results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer())) + finally: + # Join serving thread and raise any exceptions from collectAsArrowToPython + jsocket_auth_server.getResult() + + # Separate RecordBatches from batch order indices in results + batches = results[:-1] + batch_order = results[-1] + + # Re-order the batch list using the correct order + return [batches[i] for i in batch_order] + + +class SparkConversionMixin(object): + """ + Min-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession` + can use this class. + """ + def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): + from pyspark.sql import SparkSession + + assert isinstance(self, SparkSession) + + from pyspark.sql.pandas.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + if self._wrapped._conf.pandasRespectSessionTimeZone(): + timezone = self._wrapped._conf.sessionLocalTimeZone() + else: + timezone = None + + # If no schema supplied by user then get the names of columns only + if schema is None: + schema = [str(x) if not isinstance(x, basestring) else + (x.encode('utf-8') if not isinstance(x, str) else x) + for x in data.columns] + + if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0: + try: + return self._create_from_pandas_with_arrow(data, schema, timezone) + except Exception as e: + from pyspark.util import _exception_message + + if self._wrapped._conf.arrowPySparkFallbackEnabled(): + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempting non-optimization as " + "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to " + "true." 
% _exception_message(e))
+                    warnings.warn(msg)
+                else:
+                    msg = (
+                        "createDataFrame attempted Arrow optimization because "
+                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
+                        "reached the error below and will not continue because automatic "
+                        "fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' "
+                        "has been set to false.\n  %s" % _exception_message(e))
+                    warnings.warn(msg)
+                    raise
+        data = self._convert_from_pandas(data, schema, timezone)
+        return self._create_dataframe(data, schema, samplingRatio, verifySchema)
+
+    def _convert_from_pandas(self, pdf, schema, timezone):
+        """
+        Convert a pandas.DataFrame to a list of records that can be used to make a DataFrame.
+
+        :return list of records
+        """
+        from pyspark.sql import SparkSession
+
+        assert isinstance(self, SparkSession)
+
+        if timezone is not None:
+            from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
+            copied = False
+            if isinstance(schema, StructType):
+                for field in schema:
+                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+                    if isinstance(field.dataType, TimestampType):
+                        s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
+                        if s is not pdf[field.name]:
+                            if not copied:
+                                # Copy once if the series is modified to prevent the original
+                                # Pandas DataFrame from being updated
+                                pdf = pdf.copy()
+                                copied = True
+                            pdf[field.name] = s
+            else:
+                for column, series in pdf.iteritems():
+                    s = _check_series_convert_timestamps_tz_local(series, timezone)
+                    if s is not series:
+                        if not copied:
+                            # Copy once if the series is modified to prevent the original
+                            # Pandas DataFrame from being updated
+                            pdf = pdf.copy()
+                            copied = True
+                        pdf[column] = s
+
+        # Convert pandas.DataFrame to list of numpy records
+        np_records = pdf.to_records(index=False)
+
+        # Check if any columns need to be fixed for Spark to infer properly
+        if len(np_records) > 0:
+            record_dtype = self._get_numpy_record_dtype(np_records[0])
+            if record_dtype is not None:
+                return [r.astype(record_dtype).tolist() for r in np_records]
+
+        # Convert list of numpy records to python lists
+        return [r.tolist() for r in np_records]
+
+    def _get_numpy_record_dtype(self, rec):
+        """
+        Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
+        the dtypes of fields in a record so they can be properly loaded into Spark.
+
+        :param rec: a numpy record to check field dtypes
+        :return corrected dtype for a numpy.record or None if no correction needed
+        """
+        import numpy as np
+        cur_dtypes = rec.dtype
+        col_names = cur_dtypes.names
+        record_type_list = []
+        has_rec_fix = False
+        for i in xrange(len(cur_dtypes)):
+            curr_type = cur_dtypes[i]
+            # If type is a datetime64 timestamp, convert to microseconds
+            # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
+            # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
+            if curr_type == np.dtype('datetime64[ns]'):
+                curr_type = 'datetime64[us]'
+                has_rec_fix = True
+            record_type_list.append((str(col_names[i]), curr_type))
+        return np.dtype(record_type_list) if has_rec_fix else None
+
+    def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
+        """
+        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
+        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
+        data types will be used to coerce the data during the Pandas to Arrow conversion.
+ """ + from pyspark.sql import SparkSession + from pyspark.sql.dataframe import DataFrame + + assert isinstance(self, SparkSession) + + from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer + from pyspark.sql.types import TimestampType + from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type + from pyspark.sql.pandas.utils import require_minimum_pandas_version, \ + require_minimum_pyarrow_version + + require_minimum_pandas_version() + require_minimum_pyarrow_version() + + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + import pyarrow as pa + + # Create the Spark schema from list of names passed in with Arrow types + if isinstance(schema, (list, tuple)): + arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False) + struct = StructType() + for name, field in zip(schema, arrow_schema): + struct.add(name, from_arrow_type(field.type), nullable=field.nullable) + schema = struct + + # Determine arrow types to coerce data when creating batches + if isinstance(schema, StructType): + arrow_types = [to_arrow_type(f.dataType) for f in schema.fields] + elif isinstance(schema, DataType): + raise ValueError("Single data type %s is not supported with Arrow" % str(schema)) + else: + # Any timestamps must be coerced to be compatible with Spark + arrow_types = [to_arrow_type(TimestampType()) + if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None + for t in pdf.dtypes] + + # Slice the DataFrame to be batched + step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up + pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) + + # Create list of Arrow (columns, type) for serializer dump_stream + arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)] + for pdf_slice in pdf_slices] + + jsqlContext = self._wrapped._jsqlContext + + safecheck = self._wrapped._conf.arrowSafeTypeConversion() + col_by_name = True # col by name only applies to StructType columns, can't happen here + ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) + + def reader_func(temp_filename): + return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) + + def create_RDD_server(): + return self._jvm.ArrowRDDServer(jsqlContext) + + # Create Spark DataFrame from Arrow stream file, using one batch per partition + jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server) + jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext) + df = DataFrame(jdf, self._wrapped) + df._schema = schema + return df + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.pandas.conversion + globs = pyspark.sql.pandas.conversion.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.pandas.conversion tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.pandas.conversion, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py new file mode 100644 index 0000000000..26241dbe68 --- /dev/null +++ b/python/pyspark/sql/pandas/functions.py @@ -0,0 +1,539 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import sys + +from pyspark import since +from pyspark.rdd import PythonEvalType +from pyspark.sql.types import DataType +from pyspark.sql.udf import _create_udf + + +class PandasUDFType(object): + """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. + """ + SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF + + SCALAR_ITER = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + + GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + + COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF + + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF + + MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + + +@since(2.3) +def pandas_udf(f=None, returnType=None, functionType=None): + """ + Creates a vectorized user defined function (UDF). + + :param f: user-defined function. A python function if used as a standalone function + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. + Default: SCALAR. + + The function type of the UDF can be one of the following: + + 1. SCALAR + + A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. + The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`. + + :class:`MapType`, nested :class:`StructType` are currently not supported as output types. + + Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql.types import IntegerType, StringType + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP + >>> @pandas_udf(StringType()) # doctest: +SKIP + ... def to_upper(s): + ... return s.str.upper() + ... + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> df = spark.createDataFrame([(1, "John Doe", 21)], + ... ("id", "name", "age")) # doctest: +SKIP + >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ + ... .show() # doctest: +SKIP + +----------+--------------+------------+ + |slen(name)|to_upper(name)|add_one(age)| + +----------+--------------+------------+ + | 8| JOHN DOE| 22| + +----------+--------------+------------+ + >>> @pandas_udf("first string, last string") # doctest: +SKIP + ... def split_expand(n): + ... return n.str.split(expand=True) + >>> df.select(split_expand("name")).show() # doctest: +SKIP + +------------------+ + |split_expand(name)| + +------------------+ + | [John, Doe]| + +------------------+ + + .. 
note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input + column, but is the length of an internal batch used for each call to the function. + Therefore, this can be used, for example, to ensure the length of each returned + `pandas.Series`, and can not be used as the column length. + + 2. SCALAR_ITER + + A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the + wrapped Python function takes an iterator of batches as input instead of a single batch and, + instead of returning a single output batch, it yields output batches or explicitly returns an + generator or an iterator of output batches. + It is useful when the UDF execution requires initializing some state, e.g., loading a machine + learning model file to apply inference to every input batch. + + .. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all + batches from one partition, although it is currently implemented this way. + Your code shall not rely on this behavior because it might change in the future for + further optimization, e.g., one invocation processes multiple partitions. + + Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType + >>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP + >>> df = spark.createDataFrame(pdf) # doctest: +SKIP + + When the UDF is called with a single column that is not `StructType`, the input to the + underlying function is an iterator of `pd.Series`. + + >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP + ... def plus_one(batch_iter): + ... for x in batch_iter: + ... yield x + 1 + ... + >>> df.select(plus_one(col("x"))).show() # doctest: +SKIP + +-----------+ + |plus_one(x)| + +-----------+ + | 2| + | 3| + | 4| + +-----------+ + + When the UDF is called with more than one columns, the input to the underlying function is an + iterator of `pd.Series` tuple. + + >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP + ... def multiply_two_cols(batch_iter): + ... for a, b in batch_iter: + ... yield a * b + ... + >>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP + +-----------------------+ + |multiply_two_cols(x, x)| + +-----------------------+ + | 1| + | 4| + | 9| + +-----------------------+ + + When the UDF is called with a single column that is `StructType`, the input to the underlying + function is an iterator of `pd.DataFrame`. + + >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP + ... def multiply_two_nested_cols(pdf_iter): + ... for pdf in pdf_iter: + ... yield pdf["a"] * pdf["b"] + ... + >>> df.select( + ... multiply_two_nested_cols( + ... struct(col("x").alias("a"), col("x").alias("b")) + ... ).alias("y") + ... ).show() # doctest: +SKIP + +---+ + | y| + +---+ + | 1| + | 4| + | 9| + +---+ + + In the UDF, you can initialize some states before processing batches, wrap your code with + `try ... finally ...` or use context managers to ensure the release of resources at the end + or in case of early termination. + + >>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP + >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP + ... def plus_y(batch_iter): + ... y = y_bc.value # initialize some state + ... try: + ... for x in batch_iter: + ... yield x + y + ... finally: + ... 
pass # release resources here, if any + ... + >>> df.select(plus_y(col("x"))).show() # doctest: +SKIP + +---------+ + |plus_y(x)| + +---------+ + | 2| + | 3| + | 4| + +---------+ + + 3. GROUPED_MAP + + A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` + The returnType should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined returnType schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. + + Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + Alternatively, the user can define a function that takes two arguments. + In this case, the grouping key(s) will be passed as the first argument and the data will + be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key(s) in the function. + + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def mean_udf(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + >>> @pandas_udf( + ... "id long, `ceil(v / 2)` long, v double", + ... PandasUDFType.GROUPED_MAP) # doctest: +SKIP + >>> def sum_udf(key, pdf): + ... # key is a tuple of two numpy.int64s, which is the values + ... # of 'id' and 'ceil(df.v / 2)' for the current group + ... return pd.DataFrame([key + (pdf.v.sum(),)]) + >>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP + +---+-----------+----+ + | id|ceil(v / 2)| v| + +---+-----------+----+ + | 2| 5|10.0| + | 1| 1| 3.0| + | 2| 3| 5.0| + | 2| 2| 3.0| + +---+-----------+----+ + + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + + .. 
seealso:: :meth:`pyspark.sql.GroupedData.apply` + + 4. GROUPED_AGG + + A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. + The returned scalar can be either a python primitive type, e.g., `int` or `float` + or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. + + :class:`MapType` and :class:`StructType` are currently not supported as output types. + + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and + :class:`pyspark.sql.Window` + + This example shows using grouped aggregated UDFs with groupby: + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + This example shows using grouped aggregated UDFs as window functions. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> w = (Window.partitionBy('id') + ... .orderBy('v') + ... .rowsBetween(-1, 0)) + >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.0| + | 1| 2.0| 1.5| + | 2| 3.0| 3.0| + | 2| 5.0| 4.0| + | 2|10.0| 7.5| + +---+----+------+ + + .. note:: For performance reasons, the input series to window functions are not copied. + Therefore, mutating the input series is not allowed and will cause incorrect results. + For the same reason, users should also not rely on the index of the input series. + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` + + 5. MAP_ITER + + A map iterator Pandas UDFs are used to transform data with an iterator of batches. + It can be used with :meth:`pyspark.sql.DataFrame.mapInPandas`. + + It can return the output of arbitrary length in contrast to the scalar Pandas UDF. + It maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined + function and returns the result as a :class:`DataFrame`. + + The user-defined function should take an iterator of `pandas.DataFrame`\\s and return another + iterator of `pandas.DataFrame`\\s. All columns are passed together as an + iterator of `pandas.DataFrame`\\s to the user-defined function and the returned iterator of + `pandas.DataFrame`\\s are combined as a :class:`DataFrame`. + + >>> df = spark.createDataFrame([(1, 21), (2, 30)], + ... ("id", "age")) # doctest: +SKIP + >>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP + ... def filter_func(batch_iter): + ... for pdf in batch_iter: + ... yield pdf[pdf.id == 1] + >>> df.mapInPandas(filter_func).show() # doctest: +SKIP + +---+---+ + | id|age| + +---+---+ + | 1| 21| + +---+---+ + + 6. COGROUPED_MAP + + A cogrouped map UDF defines transformation: (`pandas.DataFrame`, `pandas.DataFrame`) -> + `pandas.DataFrame`. 
The `returnType` should be a :class:`StructType` describing the schema + of the returned `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` + must either match the field names in the defined `returnType` schema if specified as strings, + or match the field data types by position if not strings, e.g. integer indices. The length + of the returned `pandas.DataFrame` can be arbitrary. + + CoGrouped map UDFs are used with :meth:`pyspark.sql.CoGroupedData.apply`. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df1 = spark.createDataFrame( + ... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], + ... ("time", "id", "v1")) + >>> df2 = spark.createDataFrame( + ... [(20000101, 1, "x"), (20000101, 2, "y")], + ... ("time", "id", "v2")) + >>> @pandas_udf("time int, id int, v1 double, v2 string", + ... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP + ... def asof_join(l, r): + ... return pd.merge_asof(l, r, on="time", by="id") + >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP + +---------+---+---+---+ + | time| id| v1| v2| + +---------+---+---+---+ + | 20000101| 1|1.0| x| + | 20000102| 1|3.0| x| + | 20000101| 2|2.0| y| + | 20000102| 2|4.0| y| + +---------+---+---+---+ + + Alternatively, the user can define a function that takes three arguments. In this case, + the grouping key(s) will be passed as the first argument and the data will be passed as the + second and third arguments. The grouping key(s) will be passed as a tuple of numpy data + types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two + `pandas.DataFrame` containing all columns from the original Spark DataFrames. + >>> @pandas_udf("time int, id int, v1 double, v2 string", + ... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP + ... def asof_join(k, l, r): + ... if k == (1,): + ... return pd.merge_asof(l, r, on="time", by="id") + ... else: + ... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2']) + >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP + +---------+---+---+---+ + | time| id| v1| v2| + +---------+---+---+---+ + | 20000101| 1|1.0| x| + | 20000102| 1|3.0| x| + +---------+---+---+---+ + + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP + + .. note:: The user-defined functions do not support conditional expressions or short circuiting + in boolean expressions and it ends up with being executed all internally. If the functions + can fail on special rows, the workaround is to incorporate the condition into the functions. + + .. note:: The user-defined functions do not take keyword arguments on the calling side. + + .. note:: The data type of returned `pandas.Series` from the user-defined functions should be + matched with defined returnType (see :meth:`types.to_arrow_type` and + :meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do + conversion on returned data. 
The conversion is not guaranteed to be correct and results + should be checked for accuracy by users. + """ + + # The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that + # are not yet visible to the user. Some of behaviors are buggy and might be changed in the near + # future. The table might have to be eventually documented externally. + # Please see SPARK-28132's PR to see the codes in order to generate the table below. + # + # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa + # |SQL Type \ Pandas Value(Type)|None(object(NoneType))| True(bool)| 1(int8)| 1(int16)| 1(int32)| 1(int64)| 1(uint8)| 1(uint16)| 1(uint32)| 1(uint64)| 1.0(float16)| 1.0(float32)| 1.0(float64)|1970-01-01 00:00:00(datetime64[ns])|1970-01-01 00:00:00-05:00(datetime64[ns, US/Eastern])|a(object(string))| 1(object(Decimal))|[1 2 3](object(array[int32]))| 1.0(float128)|(1+0j)(complex64)|(1+0j)(complex128)|A(category)|1 days 00:00:00(timedelta64[ns])| # noqa + # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa + # | boolean| None| True| True| True| True| True| True| True| True| True| True| True| True| X| X| X| X| X| X| X| X| X| X| # noqa + # | tinyint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| 0| X| # noqa + # | smallint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa + # | int| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| X| X| X| 1| X| X| X| X| X| X| # noqa + # | bigint| None| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 1| 0| 18000000000000| X| 1| X| X| X| X| X| X| # noqa + # | float| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa + # | double| None| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| 1.0| X| X| X| X| X| X| X| X| X| X| # noqa + # | date| None| X| X| X|datetime.date(197...| X| X| X| X| X| X| X| X| datetime.date(197...| datetime.date(197...| X|datetime.date(197...| X| X| X| X| X| X| # noqa + # | timestamp| None| X| X| X| X|datetime.datetime...| X| X| X| X| X| X| X| datetime.datetime...| datetime.datetime...| X|datetime.datetime...| X| X| X| X| X| X| # noqa + # | string| None| ''| ''| ''| '\x01'| '\x01'| ''| ''| '\x01'| '\x01'| ''| ''| ''| X| X| 'a'| X| X| ''| X| ''| X| X| # noqa + # | decimal(10,0)| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| Decimal('1')| X| X| X| X| X| X| # noqa + # | array| None| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| [1, 2, 3]| X| X| X| X| X| # noqa + # | map| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa + # | struct<_1:int>| X| X| X| X| X| X| 
X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| X| # noqa + # | binary| None|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')| bytearray(b'\x01')| bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'\x01')|bytearray(b'')|bytearray(b'')|bytearray(b'')| bytearray(b'')| bytearray(b'')| bytearray(b'a')| X| X|bytearray(b'')| bytearray(b'')| bytearray(b'')| X| bytearray(b'')| # noqa + # +-----------------------------+----------------------+------------------+------------------+------------------+--------------------+--------------------+------------------+------------------+------------------+------------------+--------------+--------------+--------------+-----------------------------------+-----------------------------------------------------+-----------------+--------------------+-----------------------------+--------------+-----------------+------------------+-----------+--------------------------------+ # noqa + # + # Note: DDL formatted string is used for 'SQL Type' for simplicity. This string can be + # used in `returnType`. + # Note: The values inside of the table are generated by `repr`. + # Note: Python 3.7.3, Pandas 0.24.2 and PyArrow 0.13.0 are used. + # Note: Timezone is KST. + # Note: 'X' means it throws an exception during the conversion. + + # decorator @pandas_udf(returnType, functionType) + is_decorator = f is None or isinstance(f, (str, DataType)) + + if is_decorator: + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + return_type = f or returnType + + if functionType is not None: + # @pandas_udf(dataType, functionType=functionType) + # @pandas_udf(returnType=dataType, functionType=functionType) + eval_type = functionType + elif returnType is not None and isinstance(returnType, int): + # @pandas_udf(dataType, functionType) + eval_type = returnType + else: + # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF + else: + return_type = returnType + + if functionType is not None: + eval_type = functionType + else: + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF + + if return_type is None: + raise ValueError("Invalid returnType: returnType can not be None") + + if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]: + raise ValueError("Invalid functionType: " + "functionType must be one the values from PandasUDFType") + + if is_decorator: + return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) + else: + return _create_udf(f=f, returnType=return_type, evalType=eval_type) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.pandas.functions + globs = pyspark.sql.pandas.functions.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.pandas.functions tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.pandas.functions, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/pandas/group_ops.py similarity index 63% rename from python/pyspark/sql/cogroup.py 
rename to python/pyspark/sql/pandas/group_ops.py index ef87e703bc..00f01d2105 100644 --- a/python/pyspark/sql/cogroup.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -22,7 +22,83 @@ from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame -class CoGroupedData(object): +class PandasGroupedOpsMixin(object): + """ + Min-in for pandas grouped operations. Currently, only :class:`GroupedData` + can use this class. + """ + + def apply(self, udf): + """ + Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result + as a `DataFrame`. + + The user-defined function should take a `pandas.DataFrame` and return another + `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` + to the user-function and the returned `pandas.DataFrame` are combined as a + :class:`DataFrame`. + + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the + returnType of the pandas udf. + + .. note:: This function requires a full shuffle. All the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + :param udf: a grouped map user-defined function returned by + :func:`pyspark.sql.functions.pandas_udf`. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + + """ + from pyspark.sql import GroupedData + + assert isinstance(self, GroupedData) + + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " + "GROUPED_MAP.") + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) + + @since(3.0) + def cogroup(self, other): + """ + Cogroups this group with another group so that we can run cogrouped operations. + + See :class:`CoGroupedData` for the operations that can be run. + """ + from pyspark.sql import GroupedData + + assert isinstance(self, GroupedData) + + return PandasCogroupedOps(self, other) + + +class PandasCogroupedOps(object): """ A logical grouping of two :class:`GroupedData`, created by :func:`GroupedData.cogroup`. 
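For context, a minimal end-to-end sketch (illustrative only, not part of the patch) of the cogrouped path above, i.e. GroupedData.cogroup() returning a PandasCogroupedOps whose apply() runs a COGROUPED_MAP pandas_udf. It assumes an active SparkSession named `spark` with pandas and pyarrow installed; `df1`, `df2` and `join_groups` are made-up names for illustration.

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df1 = spark.createDataFrame([(1, 1.0), (2, 2.0)], ("id", "v1"))
df2 = spark.createDataFrame([(1, "x"), (2, "y")], ("id", "v2"))

@pandas_udf("id long, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def join_groups(left, right):
    # Each side arrives as a pandas.DataFrame holding one group's rows
    return pd.merge(left, right, on="id")

df1.groupby("id").cogroup(df2.groupby("id")).apply(join_groups).show()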
@@ -124,15 +200,15 @@ class CoGroupedData(object): def _test(): import doctest from pyspark.sql import SparkSession - import pyspark.sql.cogroup - globs = pyspark.sql.cogroup.__dict__.copy() + import pyspark.sql.pandas.group_ops + globs = pyspark.sql.pandas.group_ops.__dict__.copy() spark = SparkSession.builder\ .master("local[4]")\ - .appName("sql.cogroup tests")\ + .appName("sql.pandas.group tests")\ .getOrCreate() globs['spark'] = spark (failure_count, test_count) = doctest.testmod( - pyspark.sql.cogroup, globs=globs, + pyspark.sql.pandas.group_ops, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py new file mode 100644 index 0000000000..6466d60ea7 --- /dev/null +++ b/python/pyspark/sql/pandas/map_ops.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys + +from pyspark import since +from pyspark.rdd import PythonEvalType + + +class PandasMapOpsMixin(object): + """ + Min-in for pandas map operations. Currently, only :class:`DataFrame` + can use this class. + """ + + @since(3.0) + def mapInPandas(self, udf): + """ + Maps an iterator of batches in the current :class:`DataFrame` using a Pandas user-defined + function and returns the result as a :class:`DataFrame`. + + The user-defined function should take an iterator of `pandas.DataFrame`\\s and return + another iterator of `pandas.DataFrame`\\s. All columns are passed + together as an iterator of `pandas.DataFrame`\\s to the user-defined function and the + returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`. + Each `pandas.DataFrame` size can be controlled by + `spark.sql.execution.arrow.maxRecordsPerBatch`. + Its schema must match the returnType of the Pandas user-defined function. + + :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame([(1, 21), (2, 30)], + ... ("id", "age")) # doctest: +SKIP + >>> @pandas_udf(df.schema, PandasUDFType.MAP_ITER) # doctest: +SKIP + ... def filter_func(batch_iter): + ... for pdf in batch_iter: + ... yield pdf[pdf.id == 1] + >>> df.mapInPandas(filter_func).show() # doctest: +SKIP + +---+---+ + | id|age| + +---+---+ + | 1| 21| + +---+---+ + + .. 
seealso:: :meth:`pyspark.sql.functions.pandas_udf` + + """ + from pyspark.sql import Column, DataFrame + + assert isinstance(self, DataFrame) + + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.evalType != PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: + raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " + "MAP_ITER.") + + udf_column = udf(*[self[col] for col in self.columns]) + jdf = self._jdf.mapInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.pandas.map_ops + globs = pyspark.sql.pandas.map_ops.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.pandas.map_ops tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.pandas.map_ops, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py new file mode 100644 index 0000000000..4bb5b8fb17 --- /dev/null +++ b/python/pyspark/sql/pandas/serializers.py @@ -0,0 +1,281 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. +""" + +import sys +if sys.version < '3': + from itertools import izip as zip +else: + basestring = unicode = str + xrange = range + +from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer + + +class SpecialLengths(object): + END_OF_DATA_SECTION = -1 + PYTHON_EXCEPTION_THROWN = -2 + TIMING_DATA = -3 + END_OF_STREAM = -4 + NULL = -5 + START_ARROW_STREAM = -6 + + +class ArrowCollectSerializer(Serializer): + """ + Deserialize a stream of batches followed by batch order information. Used in + PandasConversionMixin._collect_as_arrow() after invoking Dataset.collectAsArrowToPython() + in the JVM. + """ + + def __init__(self): + self.serializer = ArrowStreamSerializer() + + def dump_stream(self, iterator, stream): + return self.serializer.dump_stream(iterator, stream) + + def load_stream(self, stream): + """ + Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields + a list of indices that can be used to put the RecordBatches in the correct order. 
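The load_stream() contract described above (record batches first, then a list of indices giving their correct order) is the same reindexing that _collect_as_arrow() performs on the driver; a toy sketch with plain Python lists, purely illustrative:

arrived = ["batch_b", "batch_c", "batch_a"]    # batches as they were yielded (arrival order)
batch_order = [2, 0, 1]                        # final item yielded by load_stream()
reordered = [arrived[i] for i in batch_order]  # same reordering as _collect_as_arrow()
assert reordered == ["batch_a", "batch_b", "batch_c"]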
+ """ + # load the batches + for batch in self.serializer.load_stream(stream): + yield batch + + # load the batch order indices or propagate any error that occurred in the JVM + num = read_int(stream) + if num == -1: + error_msg = UTF8Deserializer().loads(stream) + raise RuntimeError("An error occurred while calling " + "ArrowCollectSerializer.load_stream: {}".format(error_msg)) + batch_order = [] + for i in xrange(num): + index = read_int(stream) + batch_order.append(index) + yield batch_order + + def __repr__(self): + return "ArrowCollectSerializer(%s)" % self.serializer + + +class ArrowStreamSerializer(Serializer): + """ + Serializes Arrow record batches as a stream. + """ + + def dump_stream(self, iterator, stream): + import pyarrow as pa + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() + + def load_stream(self, stream): + import pyarrow as pa + reader = pa.ipc.open_stream(stream) + for batch in reader: + yield batch + + def __repr__(self): + return "ArrowStreamSerializer" + + +class ArrowStreamPandasSerializer(ArrowStreamSerializer): + """ + Serializes Pandas.Series as Arrow data with Arrow streaming format. + + :param timezone: A timezone to respect when handling timestamp values + :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation + :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name + """ + + def __init__(self, timezone, safecheck, assign_cols_by_name): + super(ArrowStreamPandasSerializer, self).__init__() + self._timezone = timezone + self._safecheck = safecheck + self._assign_cols_by_name = assign_cols_by_name + + def arrow_to_pandas(self, arrow_column): + from pyspark.sql.pandas.types import _check_series_localize_timestamps + + # If the given column is a date type column, creates a series of datetime.date directly + # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by + # datetime64[ns] type handling. + s = arrow_column.to_pandas(date_as_object=True) + + s = _check_series_localize_timestamps(s, self._timezone) + return s + + def _create_batch(self, series): + """ + Create an Arrow record batch from the given pandas.Series or list of Series, + with optional type. + + :param series: A single pandas.Series, list of Series, or list of (series, arrow_type) + :return: Arrow RecordBatch + """ + import pandas as pd + import pyarrow as pa + from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal + # Make input conform to [(series1, type1), (series2, type2), ...] + if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + def create_array(s, t): + mask = s.isnull() + # Ensure timestamp series are in expected form for Spark internal representation + if t is not None and pa.types.is_timestamp(t): + s = _check_series_convert_timestamps_internal(s, self._timezone) + try: + array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) + except pa.ArrowException as e: + error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \ + "Array (%s). It can be caused by overflows or other unsafe " + \ + "conversions warned by Arrow. 
Arrow safe type check can be " + \ + "disabled by using SQL config " + \ + "`spark.sql.execution.pandas.arrowSafeTypeConversion`." + raise RuntimeError(error_msg % (s.dtype, t), e) + return array + + arrs = [] + for s, t in series: + if t is not None and pa.types.is_struct(t): + if not isinstance(s, pd.DataFrame): + raise ValueError("A field of type StructType expects a pandas.DataFrame, " + "but got: %s" % str(type(s))) + + # Input partition and result pandas.DataFrame empty, make empty Arrays with struct + if len(s) == 0 and len(s.columns) == 0: + arrs_names = [(pa.array([], type=field.type), field.name) for field in t] + # Assign result columns by schema name if user labeled with strings + elif self._assign_cols_by_name and any(isinstance(name, basestring) + for name in s.columns): + arrs_names = [(create_array(s[field.name], field.type), field.name) + for field in t] + # Assign result columns by position + else: + arrs_names = [(create_array(s[s.columns[i]], field.type), field.name) + for i, field in enumerate(t)] + + struct_arrs, struct_names = zip(*arrs_names) + arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) + else: + arrs.append(create_array(s, t)) + + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + + def dump_stream(self, iterator, stream): + """ + Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or + a list of series accompanied by an optional pyarrow type to coerce the data to. + """ + batches = (self._create_batch(series) for series in iterator) + super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) + + def load_stream(self, stream): + """ + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. + """ + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + import pyarrow as pa + for batch in batches: + yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] + + def __repr__(self): + return "ArrowStreamPandasSerializer" + + +class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): + """ + Serializer used by Python worker to evaluate Pandas UDFs + """ + + def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False): + super(ArrowStreamPandasUDFSerializer, self) \ + .__init__(timezone, safecheck, assign_cols_by_name) + self._df_for_struct = df_for_struct + + def arrow_to_pandas(self, arrow_column): + import pyarrow.types as types + + if self._df_for_struct and types.is_struct(arrow_column.type): + import pandas as pd + series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column) + .rename(field.name) + for column, field in zip(arrow_column.flatten(), arrow_column.type)] + s = pd.concat(series, axis=1) + else: + s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column) + return s + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. 
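The Arrow streaming format that ArrowStreamSerializer wraps can be exercised with pyarrow alone; a self-contained sketch (illustrative only, not part of the patch), assuming pyarrow is installed:

import io
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["x"])
buf = io.BytesIO()
writer = pa.RecordBatchStreamWriter(buf, batch.schema)  # same writer class used in dump_stream
writer.write_batch(batch)
writer.close()

buf.seek(0)
reader = pa.ipc.open_stream(buf)  # same reader call used in load_stream
assert sum(b.num_rows for b in reader) == 3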
+ """ + + def init_stream_yield_batches(): + should_write_start_length = True + for series in iterator: + batch = self._create_batch(series) + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) + + def __repr__(self): + return "ArrowStreamPandasUDFSerializer" + + +class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): + + def load_stream(self, stream): + """ + Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two + lists of pandas.Series. + """ + import pyarrow as pa + dataframes_in_group = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 2: + batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] + batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] + yield ( + [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()], + [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()] + ) + + elif dataframes_in_group != 0: + raise ValueError( + 'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group)) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py new file mode 100644 index 0000000000..81618bd41f --- /dev/null +++ b/python/pyspark/sql/pandas/types.py @@ -0,0 +1,284 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Type-specific codes between pandas and PyArrow. Also contains some utils to correct +pandas instances during the type conversion. 
+""" + +from pyspark.sql.types import * + + +def to_arrow_type(dt): + """ Convert Spark data type to pyarrow type + """ + import pyarrow as pa + if type(dt) == BooleanType: + arrow_type = pa.bool_() + elif type(dt) == ByteType: + arrow_type = pa.int8() + elif type(dt) == ShortType: + arrow_type = pa.int16() + elif type(dt) == IntegerType: + arrow_type = pa.int32() + elif type(dt) == LongType: + arrow_type = pa.int64() + elif type(dt) == FloatType: + arrow_type = pa.float32() + elif type(dt) == DoubleType: + arrow_type = pa.float64() + elif type(dt) == DecimalType: + arrow_type = pa.decimal128(dt.precision, dt.scale) + elif type(dt) == StringType: + arrow_type = pa.string() + elif type(dt) == BinaryType: + arrow_type = pa.binary() + elif type(dt) == DateType: + arrow_type = pa.date32() + elif type(dt) == TimestampType: + # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read + arrow_type = pa.timestamp('us', tz='UTC') + elif type(dt) == ArrayType: + if type(dt.elementType) in [StructType, TimestampType]: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + arrow_type = pa.list_(to_arrow_type(dt.elementType)) + elif type(dt) == StructType: + if any(type(field.dataType) == StructType for field in dt): + raise TypeError("Nested StructType not supported in conversion to Arrow") + fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + for field in dt] + arrow_type = pa.struct(fields) + else: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + return arrow_type + + +def to_arrow_schema(schema): + """ Convert a schema from Spark to Arrow + """ + import pyarrow as pa + fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) + for field in schema] + return pa.schema(fields) + + +def from_arrow_type(at): + """ Convert pyarrow type to Spark data type. + """ + import pyarrow.types as types + if types.is_boolean(at): + spark_type = BooleanType() + elif types.is_int8(at): + spark_type = ByteType() + elif types.is_int16(at): + spark_type = ShortType() + elif types.is_int32(at): + spark_type = IntegerType() + elif types.is_int64(at): + spark_type = LongType() + elif types.is_float32(at): + spark_type = FloatType() + elif types.is_float64(at): + spark_type = DoubleType() + elif types.is_decimal(at): + spark_type = DecimalType(precision=at.precision, scale=at.scale) + elif types.is_string(at): + spark_type = StringType() + elif types.is_binary(at): + spark_type = BinaryType() + elif types.is_date32(at): + spark_type = DateType() + elif types.is_timestamp(at): + spark_type = TimestampType() + elif types.is_list(at): + if types.is_timestamp(at.value_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + spark_type = ArrayType(from_arrow_type(at.value_type)) + elif types.is_struct(at): + if any(types.is_struct(field.type) for field in at): + raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at)) + return StructType( + [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) + for field in at]) + else: + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + return spark_type + + +def from_arrow_schema(arrow_schema): + """ Convert schema from Arrow to Spark. 
+ """ + return StructType( + [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) + for field in arrow_schema]) + + +def _get_local_timezone(): + """ Get local timezone using pytz with environment variable, or dateutil. + + If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone + string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and + it reads system configuration to know the system local timezone. + + See also: + - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753 + - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338 + """ + import os + return os.environ.get('TZ', 'dateutil/:') + + +def _check_series_localize_timestamps(s, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. + + If the input series is not a timestamp series, then the same series is returned. If the input + series is a timestamp series, then a converted series is returned. + + :param s: pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series that have been converted to tz-naive + """ + from pyspark.sql.pandas.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype + tz = timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(tz).dt.tz_localize(None) + else: + return s + + +def _check_dataframe_localize_timestamps(pdf, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone + + :param pdf: pandas.DataFrame + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive + """ + from pyspark.sql.pandas.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + for column, series in pdf.iteritems(): + pdf[column] = _check_series_localize_timestamps(series, timezone) + return pdf + + +def _check_series_convert_timestamps_internal(s, timezone): + """ + Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for + Spark internal storage + + :param s: a pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone + """ + from pyspark.sql.pandas.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64_dtype(s.dtype): + # When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive + # timestamp is during the hour when the clock is adjusted backward during due to + # daylight saving time (dst). + # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to + # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize + # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either + # dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500). + # + # Here we explicit choose to use standard time. This matches the default behavior of + # pytz. 
+ # + # Here are some code to help understand this behavior: + # >>> import datetime + # >>> import pandas as pd + # >>> import pytz + # >>> + # >>> t = datetime.datetime(2015, 11, 1, 1, 30) + # >>> ts = pd.Series([t]) + # >>> tz = pytz.timezone('America/New_York') + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=True) + # 0 2015-11-01 01:30:00-04:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=False) + # 0 2015-11-01 01:30:00-05:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> str(tz.localize(t)) + # '2015-11-01 01:30:00-05:00' + tz = timezone or _get_local_timezone() + return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC') + elif is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert('UTC') + else: + return s + + +def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param from_timezone: the timezone to convert from. if None then use local timezone + :param to_timezone: the timezone to convert to. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + from pyspark.sql.pandas.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + import pandas as pd + from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype + from_tz = from_timezone or _get_local_timezone() + to_tz = to_timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(to_tz).dt.tz_localize(None) + elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: + # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. + return s.apply( + lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) + else: + return s + + +def _check_series_convert_timestamps_local_tz(s, timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param timezone: the timezone to convert to. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + return _check_series_convert_timestamps_localize(s, None, timezone) + + +def _check_series_convert_timestamps_tz_local(s, timezone): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param s: a pandas.Series + :param timezone: the timezone to convert from. if None then use local timezone + :return pandas.Series where if it is a timestamp, has been converted to tz-naive + """ + return _check_series_convert_timestamps_localize(s, timezone, None) diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py new file mode 100644 index 0000000000..481aa3e643 --- /dev/null +++ b/python/pyspark/sql/pandas/utils.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +def require_minimum_pandas_version(): + """ Raise ImportError if minimum version of Pandas is not installed + """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pandas_version = "0.23.2" + + from distutils.version import LooseVersion + try: + import pandas + have_pandas = True + except ImportError: + have_pandas = False + if not have_pandas: + raise ImportError("Pandas >= %s must be installed; however, " + "it was not found." % minimum_pandas_version) + if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): + raise ImportError("Pandas >= %s must be installed; however, " + "your version was %s." % (minimum_pandas_version, pandas.__version__)) + + +def require_minimum_pyarrow_version(): + """ Raise ImportError if minimum version of pyarrow is not installed + """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pyarrow_version = "0.15.1" + + from distutils.version import LooseVersion + import os + try: + import pyarrow + have_arrow = True + except ImportError: + have_arrow = False + if not have_arrow: + raise ImportError("PyArrow >= %s must be installed; however, " + "it was not found." % minimum_pyarrow_version) + if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): + raise ImportError("PyArrow >= %s must be installed; however, " + "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) + if os.environ.get("ARROW_PRE_0_15_IPC_FORMAT", "0") == "1": + raise RuntimeError("Arrow legacy IPC format is not supported in PySpark, " + "please unset ARROW_PRE_0_15_IPC_FORMAT") diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1c3c7778c7..bf858bcf31 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -15,6 +15,8 @@ # limitations under the License. # +# To disallow implicit relative import. Remove this once we drop Python 2. +from __future__ import absolute_import from __future__ import print_function import sys import warnings @@ -25,15 +27,16 @@ if sys.version >= '3': basestring = unicode = str xrange = range else: - from itertools import izip as zip, imap as map + from itertools import imap as map from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame +from pyspark.sql.pandas.conversion import SparkConversionMixin from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, DataType, StringType, StructType, TimestampType, \ +from pyspark.sql.types import Row, DataType, StringType, StructType, \ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \ _parse_datatype_string from pyspark.sql.utils import install_exception_handler @@ -60,7 +63,7 @@ def _monkey_patch_RDD(sparkSession): RDD.toDF = toDF -class SparkSession(object): +class SparkSession(SparkConversionMixin): """The entry point to programming Spark with the Dataset and DataFrame API. 
A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as @@ -458,135 +461,6 @@ class SparkSession(object): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema - def _get_numpy_record_dtype(self, rec): - """ - Used when converting a pandas.DataFrame to Spark using to_records(), this will correct - the dtypes of fields in a record so they can be properly loaded into Spark. - :param rec: a numpy record to check field dtypes - :return corrected dtype for a numpy.record or None if no correction needed - """ - import numpy as np - cur_dtypes = rec.dtype - col_names = cur_dtypes.names - record_type_list = [] - has_rec_fix = False - for i in xrange(len(cur_dtypes)): - curr_type = cur_dtypes[i] - # If type is a datetime64 timestamp, convert to microseconds - # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs, - # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417 - if curr_type == np.dtype('datetime64[ns]'): - curr_type = 'datetime64[us]' - has_rec_fix = True - record_type_list.append((str(col_names[i]), curr_type)) - return np.dtype(record_type_list) if has_rec_fix else None - - def _convert_from_pandas(self, pdf, schema, timezone): - """ - Convert a pandas.DataFrame to list of records that can be used to make a DataFrame - :return list of records - """ - if timezone is not None: - from pyspark.sql.types import _check_series_convert_timestamps_tz_local - copied = False - if isinstance(schema, StructType): - for field in schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) - if s is not pdf[field.name]: - if not copied: - # Copy once if the series is modified to prevent the original - # Pandas DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[field.name] = s - else: - for column, series in pdf.iteritems(): - s = _check_series_convert_timestamps_tz_local(series, timezone) - if s is not series: - if not copied: - # Copy once if the series is modified to prevent the original - # Pandas DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[column] = s - - # Convert pandas.DataFrame to list of numpy records - np_records = pdf.to_records(index=False) - - # Check if any columns need to be fixed for Spark to infer properly - if len(np_records) > 0: - record_dtype = self._get_numpy_record_dtype(np_records[0]) - if record_dtype is not None: - return [r.astype(record_dtype).tolist() for r in np_records] - - # Convert list of numpy records to python lists - return [r.tolist() for r in np_records] - - def _create_from_pandas_with_arrow(self, pdf, schema, timezone): - """ - Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting - to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the - data types will be used to coerce the data in Pandas to Arrow conversion. 
- """ - from pyspark.serializers import ArrowStreamPandasSerializer - from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType - from pyspark.sql.utils import require_minimum_pandas_version, \ - require_minimum_pyarrow_version - - require_minimum_pandas_version() - require_minimum_pyarrow_version() - - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype - import pyarrow as pa - - # Create the Spark schema from list of names passed in with Arrow types - if isinstance(schema, (list, tuple)): - arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False) - struct = StructType() - for name, field in zip(schema, arrow_schema): - struct.add(name, from_arrow_type(field.type), nullable=field.nullable) - schema = struct - - # Determine arrow types to coerce data when creating batches - if isinstance(schema, StructType): - arrow_types = [to_arrow_type(f.dataType) for f in schema.fields] - elif isinstance(schema, DataType): - raise ValueError("Single data type %s is not supported with Arrow" % str(schema)) - else: - # Any timestamps must be coerced to be compatible with Spark - arrow_types = [to_arrow_type(TimestampType()) - if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None - for t in pdf.dtypes] - - # Slice the DataFrame to be batched - step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up - pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) - - # Create list of Arrow (columns, type) for serializer dump_stream - arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)] - for pdf_slice in pdf_slices] - - jsqlContext = self._wrapped._jsqlContext - - safecheck = self._wrapped._conf.arrowSafeTypeConversion() - col_by_name = True # col by name only applies to StructType columns, can't happen here - ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) - - def reader_func(temp_filename): - return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename) - - def create_RDD_server(): - return self._jvm.ArrowRDDServer(jsqlContext) - - # Create Spark DataFrame from Arrow stream file, using one batch per partition - jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server) - jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext) - df = DataFrame(jdf, self._wrapped) - df._schema = schema - return df - @staticmethod def _create_shell_session(): """ @@ -722,46 +596,12 @@ class SparkSession(object): except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - - if self._wrapped._conf.pandasRespectSessionTimeZone(): - timezone = self._wrapped._conf.sessionLocalTimeZone() - else: - timezone = None - - # If no schema supplied by user then get the names of columns only - if schema is None: - schema = [str(x) if not isinstance(x, basestring) else - (x.encode('utf-8') if not isinstance(x, str) else x) - for x in data.columns] - - if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0: - try: - return self._create_from_pandas_with_arrow(data, schema, timezone) - except Exception as e: - from pyspark.util import _exception_message - - if self._wrapped._conf.arrowPySparkFallbackEnabled(): - msg = ( - "createDataFrame attempted Arrow optimization because " - "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "Attempting 
non-optimization as " - "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to " - "true." % _exception_message(e)) - warnings.warn(msg) - else: - msg = ( - "createDataFrame attempted Arrow optimization because " - "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has " - "reached the error below and will not continue because automatic " - "fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' " - "has been set to false.\n %s" % _exception_message(e)) - warnings.warn(msg) - raise - data = self._convert_from_pandas(data, schema, timezone) + # Create a DataFrame from pandas DataFrame. + return super(SparkSession, self).createDataFrame( + data, schema, verifySchema, samplingRatio) + return self._create_dataframe(data, schema, verifySchema, samplingRatio) + def _create_dataframe(self, data, schema, verifySchema, samplingRatio): if isinstance(schema, StructType): verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index f32513771c..f0930125e3 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -178,7 +178,7 @@ class ArrowTests(ReusedSQLTestCase): self.assertFalse(pdf_ny.equals(pdf_la)) - from pyspark.sql.types import _check_series_convert_timestamps_local_tz + from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz pdf_la_corrected = pdf_la.copy() for field in self.schema: if isinstance(field.dataType, TimestampType): @@ -311,7 +311,7 @@ class ArrowTests(ReusedSQLTestCase): self.assertTrue(pdf.equals(pdf_copy)) def test_schema_conversion_roundtrip(self): - from pyspark.sql.types import from_arrow_schema, to_arrow_schema + from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema arrow_schema = to_arrow_schema(self.schema) schema_rt = from_arrow_schema(arrow_schema) self.assertEquals(self.schema, schema_rt) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 86447a346a..94a306a66d 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1608,267 +1608,6 @@ register_input_converter(DatetimeConverter()) register_input_converter(DateConverter()) -def to_arrow_type(dt): - """ Convert Spark data type to pyarrow type - """ - import pyarrow as pa - if type(dt) == BooleanType: - arrow_type = pa.bool_() - elif type(dt) == ByteType: - arrow_type = pa.int8() - elif type(dt) == ShortType: - arrow_type = pa.int16() - elif type(dt) == IntegerType: - arrow_type = pa.int32() - elif type(dt) == LongType: - arrow_type = pa.int64() - elif type(dt) == FloatType: - arrow_type = pa.float32() - elif type(dt) == DoubleType: - arrow_type = pa.float64() - elif type(dt) == DecimalType: - arrow_type = pa.decimal128(dt.precision, dt.scale) - elif type(dt) == StringType: - arrow_type = pa.string() - elif type(dt) == BinaryType: - arrow_type = pa.binary() - elif type(dt) == DateType: - arrow_type = pa.date32() - elif type(dt) == TimestampType: - # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read - arrow_type = pa.timestamp('us', tz='UTC') - elif type(dt) == ArrayType: - if type(dt.elementType) in [StructType, TimestampType]: - raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) - arrow_type = pa.list_(to_arrow_type(dt.elementType)) - elif type(dt) == StructType: - if any(type(field.dataType) == StructType for field in dt): - raise TypeError("Nested StructType not supported in conversion to 
Arrow") - fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) - for field in dt] - arrow_type = pa.struct(fields) - else: - raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) - return arrow_type - - -def to_arrow_schema(schema): - """ Convert a schema from Spark to Arrow - """ - import pyarrow as pa - fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) - for field in schema] - return pa.schema(fields) - - -def from_arrow_type(at): - """ Convert pyarrow type to Spark data type. - """ - import pyarrow.types as types - if types.is_boolean(at): - spark_type = BooleanType() - elif types.is_int8(at): - spark_type = ByteType() - elif types.is_int16(at): - spark_type = ShortType() - elif types.is_int32(at): - spark_type = IntegerType() - elif types.is_int64(at): - spark_type = LongType() - elif types.is_float32(at): - spark_type = FloatType() - elif types.is_float64(at): - spark_type = DoubleType() - elif types.is_decimal(at): - spark_type = DecimalType(precision=at.precision, scale=at.scale) - elif types.is_string(at): - spark_type = StringType() - elif types.is_binary(at): - spark_type = BinaryType() - elif types.is_date32(at): - spark_type = DateType() - elif types.is_timestamp(at): - spark_type = TimestampType() - elif types.is_list(at): - if types.is_timestamp(at.value_type): - raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) - spark_type = ArrayType(from_arrow_type(at.value_type)) - elif types.is_struct(at): - if any(types.is_struct(field.type) for field in at): - raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at)) - return StructType( - [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) - for field in at]) - else: - raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) - return spark_type - - -def from_arrow_schema(arrow_schema): - """ Convert schema from Arrow to Spark. - """ - return StructType( - [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) - for field in arrow_schema]) - - -def _get_local_timezone(): - """ Get local timezone using pytz with environment variable, or dateutil. - - If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone - string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and - it reads system configuration to know the system local timezone. - - See also: - - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753 - - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338 - """ - import os - return os.environ.get('TZ', 'dateutil/:') - - -def _check_series_localize_timestamps(s, timezone): - """ - Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. - - If the input series is not a timestamp series, then the same series is returned. If the input - series is a timestamp series, then a converted series is returned. - - :param s: pandas.Series - :param timezone: the timezone to convert. if None then use local timezone - :return pandas.Series that have been converted to tz-naive - """ - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - - from pandas.api.types import is_datetime64tz_dtype - tz = timezone or _get_local_timezone() - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? 
- if is_datetime64tz_dtype(s.dtype): - return s.dt.tz_convert(tz).dt.tz_localize(None) - else: - return s - - -def _check_dataframe_localize_timestamps(pdf, timezone): - """ - Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone - - :param pdf: pandas.DataFrame - :param timezone: the timezone to convert. if None then use local timezone - :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive - """ - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - - for column, series in pdf.iteritems(): - pdf[column] = _check_series_localize_timestamps(series, timezone) - return pdf - - -def _check_series_convert_timestamps_internal(s, timezone): - """ - Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for - Spark internal storage - - :param s: a pandas.Series - :param timezone: the timezone to convert. if None then use local timezone - :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone - """ - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64_dtype(s.dtype): - # When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive - # timestamp is during the hour when the clock is adjusted backward during due to - # daylight saving time (dst). - # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to - # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize - # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either - # dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500). - # - # Here we explicit choose to use standard time. This matches the default behavior of - # pytz. - # - # Here are some code to help understand this behavior: - # >>> import datetime - # >>> import pandas as pd - # >>> import pytz - # >>> - # >>> t = datetime.datetime(2015, 11, 1, 1, 30) - # >>> ts = pd.Series([t]) - # >>> tz = pytz.timezone('America/New_York') - # >>> - # >>> ts.dt.tz_localize(tz, ambiguous=True) - # 0 2015-11-01 01:30:00-04:00 - # dtype: datetime64[ns, America/New_York] - # >>> - # >>> ts.dt.tz_localize(tz, ambiguous=False) - # 0 2015-11-01 01:30:00-05:00 - # dtype: datetime64[ns, America/New_York] - # >>> - # >>> str(tz.localize(t)) - # '2015-11-01 01:30:00-05:00' - tz = timezone or _get_local_timezone() - return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC') - elif is_datetime64tz_dtype(s.dtype): - return s.dt.tz_convert('UTC') - else: - return s - - -def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): - """ - Convert timestamp to timezone-naive in the specified timezone or local timezone - - :param s: a pandas.Series - :param from_timezone: the timezone to convert from. if None then use local timezone - :param to_timezone: the timezone to convert to. 
if None then use local timezone - :return pandas.Series where if it is a timestamp, has been converted to tz-naive - """ - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - - import pandas as pd - from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - from_tz = from_timezone or _get_local_timezone() - to_tz = to_timezone or _get_local_timezone() - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64tz_dtype(s.dtype): - return s.dt.tz_convert(to_tz).dt.tz_localize(None) - elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: - # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. - return s.apply( - lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None) - if ts is not pd.NaT else pd.NaT) - else: - return s - - -def _check_series_convert_timestamps_local_tz(s, timezone): - """ - Convert timestamp to timezone-naive in the specified timezone or local timezone - - :param s: a pandas.Series - :param timezone: the timezone to convert to. if None then use local timezone - :return pandas.Series where if it is a timestamp, has been converted to tz-naive - """ - return _check_series_convert_timestamps_localize(s, None, timezone) - - -def _check_series_convert_timestamps_tz_local(s, timezone): - """ - Convert timestamp to timezone-naive in the specified timezone or local timezone - - :param s: a pandas.Series - :param timezone: the timezone to convert from. if None then use local timezone - :return pandas.Series where if it is a timestamp, has been converted to tz-naive - """ - return _check_series_convert_timestamps_localize(s, timezone, None) - - def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 3557c9b1ff..7c6c6e108a 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -23,8 +23,8 @@ import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ - to_arrow_type, to_arrow_schema +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +from pyspark.sql.pandas.types import to_arrow_type from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -46,7 +46,7 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF): - from pyspark.sql.utils import require_minimum_pyarrow_version + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() argspec = _get_argspec(f) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 4260c06f06..147ac3325e 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -136,50 +136,6 @@ def toJArray(gateway, jtype, arr): return jarr -def require_minimum_pandas_version(): - """ Raise ImportError if minimum version of Pandas is not installed - """ - # TODO(HyukjinKwon): Relocate and deduplicate the version specification. 
- minimum_pandas_version = "0.23.2" - - from distutils.version import LooseVersion - try: - import pandas - have_pandas = True - except ImportError: - have_pandas = False - if not have_pandas: - raise ImportError("Pandas >= %s must be installed; however, " - "it was not found." % minimum_pandas_version) - if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): - raise ImportError("Pandas >= %s must be installed; however, " - "your version was %s." % (minimum_pandas_version, pandas.__version__)) - - -def require_minimum_pyarrow_version(): - """ Raise ImportError if minimum version of pyarrow is not installed - """ - # TODO(HyukjinKwon): Relocate and deduplicate the version specification. - minimum_pyarrow_version = "0.15.1" - - from distutils.version import LooseVersion - import os - try: - import pyarrow - have_arrow = True - except ImportError: - have_arrow = False - if not have_arrow: - raise ImportError("PyArrow >= %s must be installed; however, " - "it was not found." % minimum_pyarrow_version) - if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): - raise ImportError("PyArrow >= %s must be installed; however, " - "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) - if os.environ.get("ARROW_PRE_0_15_IPC_FORMAT", "0") == "1": - raise RuntimeError("Arrow legacy IPC format is not supported in PySpark, " - "please unset ARROW_PRE_0_15_IPC_FORMAT") - - def require_test_compiled(): """ Raise Exception if test classes are not compiled """ diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 13800cfa52..085fce6daa 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -29,7 +29,7 @@ from pyspark.util import _exception_message pandas_requirement_message = None try: - from pyspark.sql.utils import require_minimum_pandas_version + from pyspark.sql.pandas.utils import require_minimum_pandas_version require_minimum_pandas_version() except ImportError as e: # If Pandas version requirement is not satisfied, skip related tests. @@ -37,7 +37,7 @@ except ImportError as e: pyarrow_requirement_message = None try: - from pyspark.sql.utils import require_minimum_pyarrow_version + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() except ImportError as e: # If Arrow version requirement is not satisfied, skip related tests. 
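For downstream code, the visible effect of this relocation is only an import-path change: the version guards and the Spark/Arrow schema converters now live under pyspark.sql.pandas. A minimal sketch of exercising the relocated helpers (illustrative only, not part of the patch; assumes pyspark built with this change plus pandas >= 0.23.2 and pyarrow >= 0.15.1 installed, and an example schema chosen here for demonstration):

    from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
    from pyspark.sql.pandas.types import to_arrow_schema, from_arrow_schema
    from pyspark.sql.types import StructType, StructField, LongType, TimestampType

    # Fail fast if the environment does not meet the minimum versions.
    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    # Illustrative flat schema; nested StructType columns are rejected by to_arrow_type.
    spark_schema = StructType([
        StructField("id", LongType(), nullable=False),
        StructField("ts", TimestampType(), nullable=True),
    ])

    # Spark -> Arrow -> Spark round-trip, mirroring test_schema_conversion_roundtrip above.
    arrow_schema = to_arrow_schema(spark_schema)  # timestamps map to pa.timestamp('us', tz='UTC')
    assert from_arrow_schema(arrow_schema) == spark_schema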
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index bfa8d97b94..5d498421e2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -39,8 +39,10 @@ from pyspark.resourceinformation import ResourceInformation from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasUDFSerializer, CogroupUDFSerializer -from pyspark.sql.types import to_arrow_type, StructType + BatchedSerializer +from pyspark.sql.pandas.serializers import ArrowStreamPandasUDFSerializer, CogroupUDFSerializer +from pyspark.sql.pandas.types import to_arrow_type +from pyspark.sql.types import StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle diff --git a/python/setup.py b/python/setup.py index 138161ff13..965927a569 100755 --- a/python/setup.py +++ b/python/setup.py @@ -179,6 +179,8 @@ try: 'pyspark.ml.linalg', 'pyspark.ml.param', 'pyspark.sql', + 'pyspark.sql.avro', + 'pyspark.sql.pandas', 'pyspark.streaming', 'pyspark.bin', 'pyspark.sbin', diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index d39019bcda..51150a1b38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -105,7 +105,7 @@ object IntegratedUDFTestUtils extends SQLHelper { Seq( pythonExec, "-c", - "from pyspark.sql.utils import require_minimum_pandas_version;" + + "from pyspark.sql.pandas.utils import require_minimum_pandas_version;" + "require_minimum_pandas_version()"), None, "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! @@ -117,7 +117,7 @@ object IntegratedUDFTestUtils extends SQLHelper { Seq( pythonExec, "-c", - "from pyspark.sql.utils import require_minimum_pyarrow_version;" + + "from pyspark.sql.pandas.utils import require_minimum_pyarrow_version;" + "require_minimum_pyarrow_version()"), None, "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
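At the user level the behavior is unchanged: SparkSession.createDataFrame still accepts a pandas.DataFrame and still prefers the Arrow path when 'spark.sql.execution.arrow.pyspark.enabled' is true; only the dispatch now goes through SparkConversionMixin. A short smoke-test sketch (illustrative only, not part of the patch; the master setting, app name, and sample data are assumptions, and a local Spark build with pandas and pyarrow is required):

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("pandas-conversion-smoke").getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    # A tz-naive timestamp that falls in the DST-ambiguous hour discussed in pyspark/sql/pandas/types.py.
    pdf = pd.DataFrame({"id": [1, 2, 3],
                        "ts": pd.to_datetime(["2015-11-01 01:30:00"] * 3)})

    df = spark.createDataFrame(pdf)   # resolved via SparkConversionMixin, Arrow-optimized when possible
    result = df.toPandas()            # timestamps come back timezone-naive in the session time zone

    assert list(result.columns) == ["id", "ts"]
    spark.stop()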