[SPARK-24554][PYTHON][SQL] Add MapType support for PySpark with Arrow
### What changes were proposed in this pull request?

This change adds MapType support for PySpark with Arrow, when using pyarrow >= 2.0.0.

### Why are the changes needed?

MapType was previously unsupported with Arrow.

### Does this PR introduce _any_ user-facing change?

Users can now use MapType with `createDataFrame()` and `toPandas()` under Arrow optimization, and with Pandas UDFs.

### How was this patch tested?

Added new PySpark tests for `createDataFrame()`, `toPandas()`, and scalar Pandas UDFs.

Closes #30393 from BryanCutler/arrow-add-MapType-SPARK-24554.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
commit 8e2a0bdce7
parent dd32f45d20
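For context, a minimal sketch of the user-facing change (this assumes a Spark build containing this patch, with pandas installed and pyarrow >= 2.0.0; the session setup and data are illustrative):

```python
import pandas as pd

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Enable Arrow-based conversion between Spark and pandas DataFrames.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# A pandas dict column can now be carried through Arrow as map<string, long>.
pdf = pd.DataFrame({"id": [0, 1], "m": [{"a": 1}, {"b": 2, "c": 3}]})
df = spark.createDataFrame(pdf, schema="id long, m map<string, long>")

# toPandas() also round-trips the map column via Arrow.
print(df.toPandas())
```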
@@ -341,8 +341,9 @@ Supported SQL Types
 
 .. currentmodule:: pyspark.sql.types
 
-Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`MapType`,
+Currently, all Spark SQL data types are supported by Arrow-based conversion except
 :class:`ArrayType` of :class:`TimestampType`, and nested :class:`StructType`.
+:class:`MapType` is only supported when using PyArrow 2.0.0 and above.
 
 Setting Arrow Batch Size
 ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -22,7 +22,7 @@ from pyspark.rdd import _load_from_socket
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.types import IntegralType
 from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
-    DoubleType, BooleanType, TimestampType, StructType, DataType
+    DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
 from pyspark.traceback_utils import SCCallSiteSync
 
 
@@ -100,7 +100,8 @@ class PandasConversionMixin(object):
         # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
         if use_arrow:
             try:
-                from pyspark.sql.pandas.types import _check_series_localize_timestamps
+                from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
+                    _convert_map_items_to_dict
                 import pyarrow
                 # Rename columns to avoid duplicated column names.
                 tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
@@ -117,6 +118,9 @@ class PandasConversionMixin(object):
                         if isinstance(field.dataType, TimestampType):
                             pdf[field.name] = \
                                 _check_series_localize_timestamps(pdf[field.name], timezone)
+                        elif isinstance(field.dataType, MapType):
+                            pdf[field.name] = \
+                                _convert_map_items_to_dict(pdf[field.name])
                     return pdf
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
@@ -284,7 +284,6 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         should be checked for accuracy by users.
 
     Currently,
-    :class:`pyspark.sql.types.MapType`,
     :class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and
     nested :class:`pyspark.sql.types.StructType`
     are currently not supported as output types.
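With MapType removed from that unsupported list, a scalar Pandas UDF can now declare a map return type. A minimal sketch (assumes pyarrow >= 2.0.0; the identity UDF and example data are illustrative, not part of the patch):

```python
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType, MapType, StringType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([({"a": 1},), ({"b": 2},)], "m map<string, long>")

# Identity scalar Pandas UDF over a map column; each element of the
# pandas.Series arrives as a Python dict and is returned unchanged.
@pandas_udf(MapType(StringType(), LongType()))
def identity_map(s: pd.Series) -> pd.Series:
    return s

df.select(identity_map("m").alias("m")).show(truncate=False)
```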
@@ -117,7 +117,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         self._assign_cols_by_name = assign_cols_by_name
 
     def arrow_to_pandas(self, arrow_column):
-        from pyspark.sql.pandas.types import _check_series_localize_timestamps
+        from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
+            _convert_map_items_to_dict
         import pyarrow
 
         # If the given column is a date type column, creates a series of datetime.date directly
@@ -127,6 +128,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
 
         if pyarrow.types.is_timestamp(arrow_column.type):
             return _check_series_localize_timestamps(s, self._timezone)
+        elif pyarrow.types.is_map(arrow_column.type):
+            return _convert_map_items_to_dict(s)
         else:
             return s
 
@@ -147,7 +150,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         """
         import pandas as pd
         import pyarrow as pa
-        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
+        from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
+            _convert_dict_to_map_items
         from pandas.api.types import is_categorical_dtype
         # Make input conform to [(series1, type1), (series2, type2), ...]
         if not isinstance(series, (list, tuple)) or \
@@ -160,6 +164,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             # Ensure timestamp series are in expected form for Spark internal representation
             if t is not None and pa.types.is_timestamp(t):
                 s = _check_series_convert_timestamps_internal(s, self._timezone)
+            elif t is not None and pa.types.is_map(t):
+                s = _convert_dict_to_map_items(s)
             elif is_categorical_dtype(s.dtype):
                 # Note: This can be removed once minimum pyarrow version is >= 0.16.1
                 s = s.astype(s.dtypes.categories.dtype)
@@ -20,14 +20,15 @@ Type-specific codes between pandas and PyArrow. Also contains some utils to correct
 pandas instances during the type conversion.
 """
 
-from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
-    DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, ArrayType, \
-    StructType, StructField, BooleanType
+from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \
+    FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \
+    ArrayType, MapType, StructType, StructField
 
 
 def to_arrow_type(dt):
     """ Convert Spark data type to pyarrow type
     """
+    from distutils.version import LooseVersion
     import pyarrow as pa
     if type(dt) == BooleanType:
         arrow_type = pa.bool_()
@@ -58,6 +59,13 @@ def to_arrow_type(dt):
         if type(dt.elementType) in [StructType, TimestampType]:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
+    elif type(dt) == MapType:
+        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+            raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
+        if type(dt.keyType) in [StructType, TimestampType] or \
+                type(dt.valueType) in [StructType, TimestampType]:
+            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
+        arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType))
     elif type(dt) == StructType:
         if any(type(field.dataType) == StructType for field in dt):
             raise TypeError("Nested StructType not supported in conversion to Arrow")
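For reference, a small pyarrow-only sketch of the Arrow type this branch produces and why the helper functions further down are needed (assumes pyarrow >= 2.0.0; the example data is illustrative):

```python
import pyarrow as pa

# Spark's MapType(StringType(), LongType()) maps to this Arrow type.
map_type = pa.map_(pa.string(), pa.int64())
arr = pa.array([[("a", 1)], [("b", 2), ("c", 3)], None], type=map_type)

# Converting an Arrow map column to pandas yields lists of (key, value)
# items rather than dicts, which _convert_map_items_to_dict cleans up.
print(arr.to_pandas())
```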
@@ -81,6 +89,8 @@ def to_arrow_schema(schema):
 def from_arrow_type(at):
     """ Convert pyarrow type to Spark data type.
     """
+    from distutils.version import LooseVersion
+    import pyarrow as pa
     import pyarrow.types as types
     if types.is_boolean(at):
         spark_type = BooleanType()
@@ -110,6 +120,12 @@ def from_arrow_type(at):
         if types.is_timestamp(at.value_type):
             raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
         spark_type = ArrayType(from_arrow_type(at.value_type))
+    elif types.is_map(at):
+        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+            raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
+        if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type):
+            raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
+        spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type))
     elif types.is_struct(at):
         if any(types.is_struct(field.type) for field in at):
             raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
@@ -306,3 +322,23 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
     `pandas.Series` where if it is a timestamp, has been converted to tz-naive
     """
     return _check_series_convert_timestamps_localize(s, timezone, None)
+
+
+def _convert_map_items_to_dict(s):
+    """
+    Convert a series with items as list of (key, value), as made from an Arrow column of map type,
+    to dict for compatibility with non-arrow MapType columns.
+    :param s: pandas.Series of lists of (key, value) pairs
+    :return: pandas.Series of dictionaries
+    """
+    return s.apply(lambda m: None if m is None else {k: v for k, v in m})
+
+
+def _convert_dict_to_map_items(s):
+    """
+    Convert a series of dictionaries to list of (key, value) pairs to match expected data
+    for Arrow column of map type.
+    :param s: pandas.Series of dictionaries
+    :return: pandas.Series of lists of (key, value) pairs
+    """
+    return s.apply(lambda d: list(d.items()) if d is not None else None)
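Read together, the two helpers are inverses over a pandas.Series. A standalone sketch of the same logic in plain pandas (no Spark needed; the function names are local stand-ins for the private helpers above):

```python
import pandas as pd

def convert_map_items_to_dict(s):
    # list of (key, value) items -> dict, passing None through
    return s.apply(lambda m: None if m is None else {k: v for k, v in m})

def convert_dict_to_map_items(s):
    # dict -> list of (key, value) items, passing None through
    return s.apply(lambda d: list(d.items()) if d is not None else None)

s = pd.Series([{"a": 1}, None, {"b": 2, "c": 3}])
assert convert_map_items_to_dict(convert_dict_to_map_items(s)).tolist() == s.tolist()
```

One subtlety worth noting: the dict comprehension collapses duplicate keys, so an Arrow map that carried duplicate keys would lose entries on the way in.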
@@ -21,13 +21,13 @@ import threading
 import time
 import unittest
 import warnings
+from distutils.version import LooseVersion
 
 from pyspark import SparkContext, SparkConf
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import udf
 from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
-    FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, MapType, \
-    ArrayType
+    FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, ArrayType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -114,9 +114,10 @@ class ArrowTests(ReusedSQLTestCase):
         return pd.DataFrame(data=data_dict)
 
     def test_toPandas_fallback_enabled(self):
+        ts = datetime.datetime(2015, 11, 1, 0, 30)
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
-            schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
-            df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
+            schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
+            df = self.spark.createDataFrame([([ts],)], schema=schema)
             with QuietTest(self.sc):
                 with self.warnings_lock:
                     with warnings.catch_warnings(record=True) as warns:
@@ -129,10 +130,10 @@ class ArrowTests(ReusedSQLTestCase):
                         self.assertTrue(len(user_warns) > 0)
                         self.assertTrue(
                             "Attempting non-optimization" in str(user_warns[-1]))
-                        assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+                        assert_frame_equal(pdf, pd.DataFrame({"a": [[ts]]}))
 
     def test_toPandas_fallback_disabled(self):
-        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+        schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             with self.warnings_lock:
@@ -336,6 +337,62 @@ class ArrowTests(ReusedSQLTestCase):
                 self.assertTrue(expected[r][e] == result_arrow[r][e] and
                                 result[r][e] == result_arrow[r][e])
 
+    def test_createDataFrame_with_map_type(self):
+        map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
+
+        pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
+        schema = "id long, m map<string, long>"
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            df = self.spark.createDataFrame(pdf, schema=schema)
+
+        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                    self.spark.createDataFrame(pdf, schema=schema)
+        else:
+            df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
+            result = df.collect()
+            result_arrow = df_arrow.collect()
+
+            self.assertEqual(len(result), len(result_arrow))
+            for row, row_arrow in zip(result, result_arrow):
+                i, m = row
+                _, m_arrow = row_arrow
+                self.assertEqual(m, map_data[i])
+                self.assertEqual(m_arrow, map_data[i])
+
+    def test_toPandas_with_map_type(self):
+        pdf = pd.DataFrame({"id": [0, 1, 2, 3],
+                            "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]})
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")
+
+        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                    df.toPandas()
+        else:
+            pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
+            assert_frame_equal(pdf_arrow, pdf_non)
+
+    def test_toPandas_with_map_type_nulls(self):
+        pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4],
+                            "m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]})
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")
+
+        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+            with QuietTest(self.sc):
+                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                    df.toPandas()
+        else:
+            pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
+            assert_frame_equal(pdf_arrow, pdf_non)
+
     def test_createDataFrame_with_int_col_names(self):
         import numpy as np
         pdf = pd.DataFrame(np.random.rand(4, 2))
@@ -345,26 +402,28 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEqual(pdf_col_names, df_arrow.columns)
 
     def test_createDataFrame_fallback_enabled(self):
+        ts = datetime.datetime(2015, 11, 1, 0, 30)
         with QuietTest(self.sc):
             with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
                 with warnings.catch_warnings(record=True) as warns:
                     # we want the warnings to appear even if this test is run from a subclass
                     warnings.simplefilter("always")
                     df = self.spark.createDataFrame(
-                        pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+                        pd.DataFrame({"a": [[ts]]}), "a: array<timestamp>")
                     # Catch and check the last UserWarning.
                     user_warns = [
                         warn.message for warn in warns if isinstance(warn.message, UserWarning)]
                     self.assertTrue(len(user_warns) > 0)
                     self.assertTrue(
                         "Attempting non-optimization" in str(user_warns[-1]))
-                    self.assertEqual(df.collect(), [Row(a={u'a': 1})])
+                    self.assertEqual(df.collect(), [Row(a=[ts])])
 
     def test_createDataFrame_fallback_disabled(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
                 self.spark.createDataFrame(
-                    pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
+                    pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
+                    "a: array<timestamp>")
 
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
@@ -176,9 +176,9 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
                     NotImplementedError,
-                    'Invalid return type.*MapType'):
+                    'Invalid return type.*ArrayType.*TimestampType'):
                 left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
-                    lambda l, r: l, 'id long, v map<int, int>')
+                    lambda l, r: l, 'id long, v array<timestamp>')
 
     def test_wrong_args(self):
         left = self.data1
@@ -26,7 +26,7 @@ from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, \
     window
 from pyspark.sql.types import IntegerType, DoubleType, ArrayType, BinaryType, ByteType, \
     LongType, DecimalType, ShortType, FloatType, StringType, BooleanType, StructType, \
-    StructField, NullType, MapType, TimestampType
+    StructField, NullType, TimestampType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -246,10 +246,10 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
                     NotImplementedError,
-                    'Invalid return type.*grouped map Pandas UDF.*MapType'):
+                    'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
                 pandas_udf(
                     lambda pdf: pdf,
-                    'id long, v map<int, int>',
+                    'id long, v array<timestamp>',
                     PandasUDFType.GROUPED_MAP)
 
     def test_wrong_args(self):
@@ -276,7 +276,6 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
     def test_unsupported_types(self):
         common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*'
         unsupported_types = [
-            StructField('map', MapType(StringType(), IntegerType())),
             StructField('arr_ts', ArrayType(TimestampType())),
             StructField('null', NullType()),
             StructField('struct', StructType([StructField('l', LongType())])),
@@ -21,7 +21,7 @@ from pyspark.rdd import PythonEvalType
 from pyspark.sql import Row
 from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
     udf, pandas_udf, PandasUDFType
-from pyspark.sql.types import ArrayType, TimestampType, DoubleType, MapType
+from pyspark.sql.types import ArrayType, TimestampType
 from pyspark.sql.utils import AnalysisException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -159,7 +159,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
-                @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
+                @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
                 def mean_and_std_udf(v):
                     return {v.mean(): v.std()}
 
@@ -22,6 +22,7 @@ import time
 import unittest
 from datetime import date, datetime
 from decimal import Decimal
+from distutils.version import LooseVersion
 
 from pyspark import TaskContext
 from pyspark.rdd import PythonEvalType
@@ -379,6 +380,20 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                         'Invalid return type with scalar Pandas UDFs'):
                     pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)
 
+    def test_vectorized_udf_map_type(self):
+        data = [({},), ({"a": 1},), ({"a": 1, "b": 2},), ({"a": 1, "b": 2, "c": 3},)]
+        schema = StructType([StructField("map", MapType(StringType(), LongType()))])
+        df = self.spark.createDataFrame(data, schema=schema)
+        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
+            if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                with QuietTest(self.sc):
+                    with self.assertRaisesRegex(Exception, "MapType.*not supported"):
+                        pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
+            else:
+                map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
+                result = df.select(map_f(col('map')))
+                self.assertEquals(df.collect(), result.collect())
+
     def test_vectorized_udf_complex(self):
         df = self.spark.range(10).select(
             col('id').cast('int').alias('a'),
@@ -504,8 +519,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
             with self.assertRaisesRegexp(
                     NotImplementedError,
-                    'Invalid return type.*scalar Pandas UDF.*MapType'):
-                pandas_udf(lambda x: x, MapType(LongType(), LongType()), udf_type)
+                    'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
+                pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
 
     def test_vectorized_udf_return_scalar(self):
         df = self.spark.range(10)
@@ -577,8 +592,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
             with self.assertRaisesRegexp(
                     NotImplementedError,
-                    'Invalid return type.*scalar Pandas UDF.*MapType'):
-                pandas_udf(lambda x: x, MapType(StringType(), IntegerType()), udf_type)
+                    'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
+                pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
             with self.assertRaisesRegexp(
                     NotImplementedError,
                     'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'):
@@ -1903,7 +1903,7 @@ object SQLConf {
       "1. pyspark.sql.DataFrame.toPandas " +
       "2. pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame " +
       "The following data types are unsupported: " +
-      "MapType, ArrayType of TimestampType, and nested StructType.")
+      "ArrayType of TimestampType, and nested StructType.")
     .version("3.0.0")
     .fallbackConf(ARROW_EXECUTION_ENABLED)
 
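For completeness, a minimal sketch of toggling the confs this message documents from PySpark (assumes an active session; conf names are taken from the diff above):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Arrow optimization for toPandas() and createDataFrame(pandas_df).
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
# Fall back to the non-Arrow path if the Arrow conversion fails up front.
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")
```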
@@ -63,10 +63,10 @@ object ArrowWriter {
       val elementVector = createFieldWriter(vector.getDataVector())
       new ArrayWriter(vector, elementVector)
     case (MapType(_, _, _), vector: MapVector) =>
-      val entryWriter = createFieldWriter(vector.getDataVector).asInstanceOf[StructWriter]
-      val keyWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.KEY_NAME))
-      val valueWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.VALUE_NAME))
-      new MapWriter(vector, keyWriter, valueWriter)
+      val structVector = vector.getDataVector.asInstanceOf[StructVector]
+      val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
+      val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
+      new MapWriter(vector, structVector, keyWriter, valueWriter)
     case (StructType(_), vector: StructVector) =>
       val children = (0 until vector.size()).map { ordinal =>
         createFieldWriter(vector.getChildByOrdinal(ordinal))
@@ -331,11 +331,11 @@ private[arrow] class StructWriter(
   override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
     val struct = input.getStruct(ordinal, children.length)
     var i = 0
+    valueVector.setIndexDefined(count)
     while (i < struct.numFields) {
       children(i).write(struct, i)
       i += 1
     }
-    valueVector.setIndexDefined(count)
   }
 
   override def finish(): Unit = {
@@ -351,6 +351,7 @@ private[arrow] class StructWriter(
 
 private[arrow] class MapWriter(
     val valueVector: MapVector,
+    val structVector: StructVector,
     val keyWriter: ArrowFieldWriter,
     val valueWriter: ArrowFieldWriter) extends ArrowFieldWriter {
 
@@ -363,6 +364,7 @@ private[arrow] class MapWriter(
     val values = map.valueArray()
     var i = 0
     while (i < map.numElements()) {
+      structVector.setIndexDefined(keyWriter.count)
       keyWriter.write(keys, i)
       valueWriter.write(values, i)
       i += 1