[SPARK-24554][PYTHON][SQL] Add MapType support for PySpark with Arrow

### What changes were proposed in this pull request?

This change adds MapType support for PySpark with Arrow when pyarrow >= 2.0.0 is used.

### Why are the changes needed?

MapType was previously unsupported with Arrow.

### Does this PR introduce _any_ user-facing change?

Users can now use MapType columns with `createDataFrame()` and `toPandas()` under the Arrow optimization, and with Pandas UDFs.
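
Roughly, the following becomes possible (a minimal sketch, assuming pyarrow >= 2.0.0 is installed and a local SparkSession; the column names are illustrative only):

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()
# Enable the Arrow optimization for createDataFrame()/toPandas()
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# createDataFrame() from a pandas DataFrame with a dict column, converted via Arrow
pdf = pd.DataFrame({"id": [0, 1], "m": [{"a": 1}, {"b": 2, "c": 3}]})
df = spark.createDataFrame(pdf, schema="id long, m map<string, long>")

# toPandas() with Arrow returns the map column as Python dicts
result_pdf = df.toPandas()

# Scalar Pandas UDF that takes and returns a MapType column
@pandas_udf("map<string, long>")
def passthrough(s: pd.Series) -> pd.Series:
    return s

df.select(passthrough("m")).show(truncate=False)
```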

### How was this patch tested?

Added new PySpark tests for `createDataFrame()`, `toPandas()`, and scalar Pandas UDFs.

Closes #30393 from BryanCutler/arrow-add-MapType-SPARK-24554.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Bryan Cutler 2020-11-18 21:18:19 +09:00 committed by HyukjinKwon
parent dd32f45d20
commit 8e2a0bdce7
12 changed files with 157 additions and 36 deletions


@@ -341,8 +341,9 @@ Supported SQL Types
.. currentmodule:: pyspark.sql.types
Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`MapType`,
Currently, all Spark SQL data types are supported by Arrow-based conversion except
:class:`ArrayType` of :class:`TimestampType`, and nested :class:`StructType`.
:class:`MapType` is only supported when using PyArrow 2.0.0 and above.
Setting Arrow Batch Size
~~~~~~~~~~~~~~~~~~~~~~~~


@@ -22,7 +22,7 @@ from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, BooleanType, TimestampType, StructType, DataType
DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
from pyspark.traceback_utils import SCCallSiteSync
@@ -100,7 +100,8 @@ class PandasConversionMixin(object):
# of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.pandas.types import _check_series_localize_timestamps
from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
_convert_map_items_to_dict
import pyarrow
# Rename columns to avoid duplicated column names.
tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
@@ -117,6 +118,9 @@ class PandasConversionMixin(object):
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_localize_timestamps(pdf[field.name], timezone)
elif isinstance(field.dataType, MapType):
pdf[field.name] = \
_convert_map_items_to_dict(pdf[field.name])
return pdf
else:
return pd.DataFrame.from_records([], columns=self.columns)


@@ -284,7 +284,6 @@ def pandas_udf(f=None, returnType=None, functionType=None):
should be checked for accuracy by users.
Currently,
:class:`pyspark.sql.types.MapType`,
:class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and
nested :class:`pyspark.sql.types.StructType`
are currently not supported as output types.


@@ -117,7 +117,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
self._assign_cols_by_name = assign_cols_by_name
def arrow_to_pandas(self, arrow_column):
from pyspark.sql.pandas.types import _check_series_localize_timestamps
from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
_convert_map_items_to_dict
import pyarrow
# If the given column is a date type column, creates a series of datetime.date directly
@@ -127,6 +128,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
if pyarrow.types.is_timestamp(arrow_column.type):
return _check_series_localize_timestamps(s, self._timezone)
elif pyarrow.types.is_map(arrow_column.type):
return _convert_map_items_to_dict(s)
else:
return s
@@ -147,7 +150,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
_convert_dict_to_map_items
from pandas.api.types import is_categorical_dtype
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
@@ -160,6 +164,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif t is not None and pa.types.is_map(t):
s = _convert_dict_to_map_items(s)
elif is_categorical_dtype(s.dtype):
# Note: This can be removed once minimum pyarrow version is >= 0.16.1
s = s.astype(s.dtypes.categories.dtype)


@@ -20,14 +20,15 @@ Type-specific codes between pandas and PyArrow. Also contains some utils to corr
pandas instances during the type conversion.
"""
from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, ArrayType, \
StructType, StructField, BooleanType
from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \
ArrayType, MapType, StructType, StructField
def to_arrow_type(dt):
""" Convert Spark data type to pyarrow type
"""
from distutils.version import LooseVersion
import pyarrow as pa
if type(dt) == BooleanType:
arrow_type = pa.bool_()
@@ -58,6 +59,13 @@ def to_arrow_type(dt):
if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
elif type(dt) == MapType:
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
if type(dt.keyType) in [StructType, TimestampType] or \
type(dt.valueType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType))
elif type(dt) == StructType:
if any(type(field.dataType) == StructType for field in dt):
raise TypeError("Nested StructType not supported in conversion to Arrow")
@@ -81,6 +89,8 @@ def to_arrow_schema(schema):
def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
from distutils.version import LooseVersion
import pyarrow as pa
import pyarrow.types as types
if types.is_boolean(at):
spark_type = BooleanType()
@@ -110,6 +120,12 @@ def from_arrow_type(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
elif types.is_map(at):
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type))
elif types.is_struct(at):
if any(types.is_struct(field.type) for field in at):
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
@@ -306,3 +322,23 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
`pandas.Series` where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, timezone, None)
def _convert_map_items_to_dict(s):
"""
Convert a series with items as list of (key, value), as made from an Arrow column of map type,
to dict for compatibility with non-arrow MapType columns.
:param s: pandas.Series of lists of (key, value) pairs
:return: pandas.Series of dictionaries
"""
return s.apply(lambda m: None if m is None else {k: v for k, v in m})
def _convert_dict_to_map_items(s):
"""
Convert a series of dictionaries to list of (key, value) pairs to match expected data
for Arrow column of map type.
:param s: pandas.Series of dictionaries
:return: pandas.Series of lists of (key, value) pairs
"""
return s.apply(lambda d: list(d.items()) if d is not None else None)
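
As an aside (not part of the diff), a standalone sketch of the round trip these two helpers perform on plain pandas Series:

```python
import pandas as pd

# Arrow map columns come back as lists of (key, value) pairs per row.
arrow_items = pd.Series([[("a", 1), ("b", 2)], None])

# Same logic as _convert_map_items_to_dict: pairs -> dict, None passes through.
as_dicts = arrow_items.apply(lambda m: None if m is None else {k: v for k, v in m})
# as_dicts now holds [{'a': 1, 'b': 2}, None]

# Same logic as _convert_dict_to_map_items: dict -> pairs, the inverse direction.
back_to_items = as_dicts.apply(lambda d: list(d.items()) if d is not None else None)
# back_to_items now holds [[('a', 1), ('b', 2)], None]
```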


@@ -21,13 +21,13 @@ import threading
import time
import unittest
import warnings
from distutils.version import LooseVersion
from pyspark import SparkContext, SparkConf
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, MapType, \
ArrayType
FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, ArrayType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -114,9 +114,10 @@ class ArrowTests(ReusedSQLTestCase):
return pd.DataFrame(data=data_dict)
def test_toPandas_fallback_enabled(self):
ts = datetime.datetime(2015, 11, 1, 0, 30)
with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
df = self.spark.createDataFrame([([ts],)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
with warnings.catch_warnings(record=True) as warns:
@@ -129,10 +130,10 @@ class ArrowTests(ReusedSQLTestCase):
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempting non-optimization" in str(user_warns[-1]))
assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
assert_frame_equal(pdf, pd.DataFrame({"a": [[ts]]}))
def test_toPandas_fallback_disabled(self):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
@@ -336,6 +337,62 @@ class ArrowTests(ReusedSQLTestCase):
self.assertTrue(expected[r][e] == result_arrow[r][e] and
result[r][e] == result_arrow[r][e])
def test_createDataFrame_with_map_type(self):
map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
schema = "id long, m map<string, long>"
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema=schema)
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
self.spark.createDataFrame(pdf, schema=schema)
else:
df_arrow = self.spark.createDataFrame(pdf, schema=schema)
result = df.collect()
result_arrow = df_arrow.collect()
self.assertEqual(len(result), len(result_arrow))
for row, row_arrow in zip(result, result_arrow):
i, m = row
_, m_arrow = row_arrow
self.assertEqual(m, map_data[i])
self.assertEqual(m_arrow, map_data[i])
def test_toPandas_with_map_type(self):
pdf = pd.DataFrame({"id": [0, 1, 2, 3],
"m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]})
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
df.toPandas()
else:
pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
assert_frame_equal(pdf_arrow, pdf_non)
def test_toPandas_with_map_type_nulls(self):
pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4],
"m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]})
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
df.toPandas()
else:
pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
assert_frame_equal(pdf_arrow, pdf_non)
def test_createDataFrame_with_int_col_names(self):
import numpy as np
pdf = pd.DataFrame(np.random.rand(4, 2))
@@ -345,26 +402,28 @@ class ArrowTests(ReusedSQLTestCase):
self.assertEqual(pdf_col_names, df_arrow.columns)
def test_createDataFrame_fallback_enabled(self):
ts = datetime.datetime(2015, 11, 1, 0, 30)
with QuietTest(self.sc):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
with warnings.catch_warnings(record=True) as warns:
# we want the warnings to appear even if this test is run from a subclass
warnings.simplefilter("always")
df = self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
pd.DataFrame({"a": [[ts]]}), "a: array<timestamp>")
# Catch and check the last UserWarning.
user_warns = [
warn.message for warn in warns if isinstance(warn.message, UserWarning)]
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempting non-optimization" in str(user_warns[-1]))
self.assertEqual(df.collect(), [Row(a={u'a': 1})])
self.assertEqual(df.collect(), [Row(a=[ts])])
def test_createDataFrame_fallback_disabled(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
"a: array<timestamp>")
# Regression test for SPARK-23314
def test_timestamp_dst(self):


@@ -176,9 +176,9 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*MapType'):
'Invalid return type.*ArrayType.*TimestampType'):
left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
lambda l, r: l, 'id long, v map<int, int>')
lambda l, r: l, 'id long, v array<timestamp>')
def test_wrong_args(self):
left = self.data1


@@ -26,7 +26,7 @@ from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf
window
from pyspark.sql.types import IntegerType, DoubleType, ArrayType, BinaryType, ByteType, \
LongType, DecimalType, ShortType, FloatType, StringType, BooleanType, StructType, \
StructField, NullType, MapType, TimestampType
StructField, NullType, TimestampType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -246,10 +246,10 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*grouped map Pandas UDF.*MapType'):
'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
'id long, v array<timestamp>',
PandasUDFType.GROUPED_MAP)
def test_wrong_args(self):
@@ -276,7 +276,6 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
def test_unsupported_types(self):
common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*'
unsupported_types = [
StructField('map', MapType(StringType(), IntegerType())),
StructField('arr_ts', ArrayType(TimestampType())),
StructField('null', NullType()),
StructField('struct', StructType([StructField('l', LongType())])),


@@ -21,7 +21,7 @@ from pyspark.rdd import PythonEvalType
from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
udf, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, TimestampType, DoubleType, MapType
from pyspark.sql.types import ArrayType, TimestampType
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
@@ -159,7 +159,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
@pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
@pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return {v.mean(): v.std()}


@@ -22,6 +22,7 @@ import time
import unittest
from datetime import date, datetime
from decimal import Decimal
from distutils.version import LooseVersion
from pyspark import TaskContext
from pyspark.rdd import PythonEvalType
@@ -379,6 +380,20 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
'Invalid return type with scalar Pandas UDFs'):
pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)
def test_vectorized_udf_map_type(self):
data = [({},), ({"a": 1},), ({"a": 1, "b": 2},), ({"a": 1, "b": 2, "c": 3},)]
schema = StructType([StructField("map", MapType(StringType(), LongType()))])
df = self.spark.createDataFrame(data, schema=schema)
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*not supported"):
pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
else:
map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
result = df.select(map_f(col('map')))
self.assertEquals(df.collect(), result.collect())
def test_vectorized_udf_complex(self):
df = self.spark.range(10).select(
col('id').cast('int').alias('a'),
@@ -504,8 +519,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x, MapType(LongType(), LongType()), udf_type)
'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
def test_vectorized_udf_return_scalar(self):
df = self.spark.range(10)
@@ -577,8 +592,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()), udf_type)
'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'):

View file

@@ -1903,7 +1903,7 @@ object SQLConf {
"1. pyspark.sql.DataFrame.toPandas " +
"2. pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame " +
"The following data types are unsupported: " +
"MapType, ArrayType of TimestampType, and nested StructType.")
"ArrayType of TimestampType, and nested StructType.")
.version("3.0.0")
.fallbackConf(ARROW_EXECUTION_ENABLED)


@@ -63,10 +63,10 @@ object ArrowWriter {
val elementVector = createFieldWriter(vector.getDataVector())
new ArrayWriter(vector, elementVector)
case (MapType(_, _, _), vector: MapVector) =>
val entryWriter = createFieldWriter(vector.getDataVector).asInstanceOf[StructWriter]
val keyWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.KEY_NAME))
val valueWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.VALUE_NAME))
new MapWriter(vector, keyWriter, valueWriter)
val structVector = vector.getDataVector.asInstanceOf[StructVector]
val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
new MapWriter(vector, structVector, keyWriter, valueWriter)
case (StructType(_), vector: StructVector) =>
val children = (0 until vector.size()).map { ordinal =>
createFieldWriter(vector.getChildByOrdinal(ordinal))
@@ -331,11 +331,11 @@ private[arrow] class StructWriter(
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val struct = input.getStruct(ordinal, children.length)
var i = 0
valueVector.setIndexDefined(count)
while (i < struct.numFields) {
children(i).write(struct, i)
i += 1
}
valueVector.setIndexDefined(count)
}
override def finish(): Unit = {
@@ -351,6 +351,7 @@ private[arrow] class StructWriter(
private[arrow] class MapWriter(
val valueVector: MapVector,
val structVector: StructVector,
val keyWriter: ArrowFieldWriter,
val valueWriter: ArrowFieldWriter) extends ArrowFieldWriter {
@@ -363,6 +364,7 @@ private[arrow] class MapWriter(
val values = map.valueArray()
var i = 0
while (i < map.numElements()) {
structVector.setIndexDefined(keyWriter.count)
keyWriter.write(keys, i)
valueWriter.write(values, i)
i += 1