From 8f0e88df03a06a91bb61c6e0d69b1b19e2bfb3f7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 15 Nov 2017 23:35:13 +0900 Subject: [PATCH] [SPARK-20791][PYTHON][FOLLOWUP] Check for unicode column names in createDataFrame with Arrow ## What changes were proposed in this pull request? If the schema is passed as a list of unicode strings for column names, they should be re-encoded to 'utf-8' to be consistent with StructField names. This is similar to #13097, but for creating a DataFrame using Arrow. ## How was this patch tested? Added a new test that uses unicode names for the schema. Author: Bryan Cutler Closes #19738 from BryanCutler/arrow-createDataFrame-followup-unicode-SPARK-20791. --- python/pyspark/sql/session.py | 7 ++++--- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 589365b083..dbbcfff6db 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -592,6 +592,9 @@ class SparkSession(object): if isinstance(schema, basestring): schema = _parse_datatype_string(schema) + elif isinstance(schema, (list, tuple)): + # Must re-encode any unicode strings to be consistent with StructField names + schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] try: import pandas @@ -602,7 +605,7 @@ class SparkSession(object): # If no schema supplied by user then get the names of columns only if schema is None: - schema = [str(x) for x in data.columns] + schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: @@ -630,8 +633,6 @@ class SparkSession(object): verify_func(obj) return obj, else: - if isinstance(schema, (list, tuple)): - schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] prepare = lambda obj: obj if isinstance(data, RDD): diff --git a/python/pyspark/sql/tests.py 
b/python/pyspark/sql/tests.py index 6356d938db..ef592c2356 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3225,6 +3225,16 @@ class ArrowTests(ReusedSQLTestCase): df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg')) self.assertEquals(df.schema.fieldNames(), list('abcdefg')) + def test_createDataFrame_column_name_encoding(self): + import pandas as pd + pdf = pd.DataFrame({u'a': [1]}) + columns = self.spark.createDataFrame(pdf).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'a') + columns = self.spark.createDataFrame(pdf, [u'b']).columns + self.assertTrue(isinstance(columns[0], str)) + self.assertEquals(columns[0], 'b') + def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc):