From 2537fe8cbaf49070137d4b5bc39af078b306c4c8 Mon Sep 17 00:00:00 2001
From: itholic
Date: Wed, 7 Jul 2021 15:14:18 +0900
Subject: [PATCH] [SPARK-35929][PYTHON] Support to infer nested dict as a
 struct when creating a DataFrame

### What changes were proposed in this pull request?

Currently, nested dicts are always inferred as `MapType`. This causes an issue because the map's value type is taken from the first field of the nested dict, as shown below:

```python
data = [{"inside_struct": {"payment": 100.5, "name": "Lee"}}]
df = spark.createDataFrame(data)
df.show(truncate=False)
+--------------------------------+
|inside_struct                   |
+--------------------------------+
|{name -> null, payment -> 100.5}|
+--------------------------------+
```

The "name" became `null`, but it should have been `"Lee"`. In this case, we need to be able to infer the schema as a `StructType` instead of a `MapType`.

Therefore, this PR proposes adding a new configuration, `spark.sql.pyspark.inferNestedDictAsStruct.enabled`, to control which type is used when inferring nested dicts:

- When `spark.sql.pyspark.inferNestedDictAsStruct.enabled` is `false` (the default), nested dicts are inferred as `MapType`.
- When `spark.sql.pyspark.inferNestedDictAsStruct.enabled` is `true`, nested dicts are inferred as `StructType`.
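As a minimal usage sketch (not part of the patch itself, assuming an active `SparkSession` named `spark`; the schema and row in the comments are the expected results implied by the inference logic in this change):

```python
# Enable inferring nested dicts as StructType (default is false -> MapType).
spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", True)

data = [{"inside_struct": {"payment": 100.5, "name": "Lee"}}]
df = spark.createDataFrame(data)

df.printSchema()
# root
#  |-- inside_struct: struct (nullable = true)
#  |    |-- payment: double (nullable = true)
#  |    |-- name: string (nullable = true)

df.first()
# Row(inside_struct=Row(payment=100.5, name='Lee'))
```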
### Why are the changes needed?

Because always inferring nested dicts as `MapType` doesn't work properly in some cases, as shown above.

### Does this PR introduce _any_ user-facing change?

A new configuration `spark.sql.pyspark.inferNestedDictAsStruct.enabled` is added.

### How was this patch tested?

Added a unit test.

Closes #33214 from itholic/SPARK-35929.

Lead-authored-by: itholic
Co-authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/sql/session.py               | 13 +++++++---
 python/pyspark/sql/tests/test_types.py      | 15 +++++++++++
 python/pyspark/sql/types.py                 | 26 ++++++++++++-------
 .../apache/spark/sql/internal/SQLConf.scala |  9 +++++++
 4 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 740ceb31f7..f3a63de4df 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -436,7 +436,9 @@ class SparkSession(SparkConversionMixin):
         """
         if not data:
             raise ValueError("can not infer schema from empty dataset")
-        schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
+        schema = reduce(_merge_type, (_infer_schema(row, names, infer_dict_as_struct)
+                                      for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
         return schema
@@ -462,11 +464,13 @@ class SparkSession(SparkConversionMixin):
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")
 
+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
         if samplingRatio is None:
-            schema = _infer_schema(first, names=names)
+            schema = _infer_schema(first, names=names, infer_dict_as_struct=infer_dict_as_struct)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row, names=names))
+                    schema = _merge_type(schema, _infer_schema(
+                        row, names=names, infer_dict_as_struct=infer_dict_as_struct))
                     if not _has_nulltype(schema):
                         break
                 else:
@@ -475,7 +479,8 @@ class SparkSession(SparkConversionMixin):
         else:
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
+            schema = rdd.map(lambda row: _infer_schema(
+                row, names, infer_dict_as_struct=infer_dict_as_struct)).reduce(_merge_type)
         return schema
 
     def _createFromRDD(self, rdd, schema, samplingRatio):
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index eb4caf05d1..0bb1f00b7a 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -204,6 +204,21 @@ class TypesTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
 
+    def test_infer_nested_dict_as_struct(self):
+        # SPARK-35929: Test inferring nested dict as a struct type.
+        NestedRow = Row("f1", "f2")
+
+        with self.sql_conf({"spark.sql.pyspark.inferNestedDictAsStruct.enabled": True}):
+            data = [NestedRow([{"payment": 200.5, "name": "A"}], [1, 2]),
+                    NestedRow([{"payment": 100.5, "name": "B"}], [2, 3])]
+
+            nestedRdd = self.sc.parallelize(data)
+            df = self.spark.createDataFrame(nestedRdd)
+            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.first())
+
+            df = self.spark.createDataFrame(data)
+            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.first())
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{'a': 1}], ["b"])
         self.assertEqual(df.columns, ['b'])
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 78c7732f04..e3d8f49c03 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1003,7 +1003,7 @@ if sys.version_info[0] < 4:
     _array_type_mappings['u'] = StringType
 
 
-def _infer_type(obj):
+def _infer_type(obj, infer_dict_as_struct=False):
     """Infer the DataType from obj
     """
     if obj is None:
@@ -1020,14 +1020,22 @@ def _infer_type(obj):
         return dataType()
 
     if isinstance(obj, dict):
-        for key, value in obj.items():
-            if key is not None and value is not None:
-                return MapType(_infer_type(key), _infer_type(value), True)
-        return MapType(NullType(), NullType(), True)
+        if infer_dict_as_struct:
+            struct = StructType()
+            for key, value in obj.items():
+                if key is not None and value is not None:
+                    struct.add(key, _infer_type(value, infer_dict_as_struct), True)
+            return struct
+        else:
+            for key, value in obj.items():
+                if key is not None and value is not None:
+                    return MapType(_infer_type(key, infer_dict_as_struct),
+                                   _infer_type(value, infer_dict_as_struct), True)
+            return MapType(NullType(), NullType(), True)
     elif isinstance(obj, list):
         for v in obj:
             if v is not None:
-                return ArrayType(_infer_type(obj[0]), True)
+                return ArrayType(_infer_type(obj[0], infer_dict_as_struct), True)
         return ArrayType(NullType(), True)
     elif isinstance(obj, array):
         if obj.typecode in _array_type_mappings:
@@ -1036,12 +1044,12 @@ def _infer_type(obj):
             raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
-            return _infer_schema(obj)
+            return _infer_schema(obj, infer_dict_as_struct=infer_dict_as_struct)
         except TypeError:
             raise TypeError("not supported type: %s" % type(obj))
 
 
-def _infer_schema(row, names=None):
+def _infer_schema(row, names=None, infer_dict_as_struct=False):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1067,7 +1075,7 @@ def _infer_schema(row, names=None):
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, _infer_type(v), True))
+            fields.append(StructField(k, _infer_type(v, infer_dict_as_struct), True))
         except TypeError as e:
             raise TypeError("Unable to infer the type of the field {}.".format(k)) from e
     return StructType(fields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index cc53d92c46..e9c5f6e9be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3335,6 +3335,13 @@ object SQLConf {
     .intConf
     .createWithDefault(0)
 
+  val INFER_NESTED_DICT_AS_STRUCT = buildConf("spark.sql.pyspark.inferNestedDictAsStruct.enabled")
+    .doc("PySpark's SparkSession.createDataFrame infers the nested dict as a map by default. " +
+      "When it set to true, it infers the nested dict as a struct.")
+    .version("3.3.0")
+    .booleanConf
+    .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -4048,6 +4055,8 @@ class SQLConf extends Serializable with Logging {
 
   def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)
 
+  def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */