[SPARK-20791][PYTHON][FOLLOWUP] Check for unicode column names in createDataFrame with Arrow

## What changes were proposed in this pull request?

If the schema is passed as a list of unicode strings for column names, they should be re-encoded to 'utf-8' to be consistent with how StructField names are handled. This is similar to #13097, but for creating a DataFrame using Arrow.
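For illustration only (not part of the patch itself), here is a minimal sketch of the re-encoding the patch applies to a list of column names. The `names` list is a made-up example, and the described behavior assumes Python 2, where `unicode` and `str` are distinct types:

```python
# Sketch of the list comprehension used by the patch (Python 2 semantics).
names = [u'id', 'value']           # mix of unicode and str column names

# unicode entries are encoded to utf-8 bytes (str in Python 2);
# entries that are already str are left untouched.
encoded = [x.encode('utf-8') if not isinstance(x, str) else x for x in names]

print(encoded)                      # ['id', 'value'] -- all plain str
print([type(x) for x in encoded])   # [<type 'str'>, <type 'str'>]
```

Under Python 3 the comprehension is effectively a no-op, since unicode strings are already instances of `str`.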

## How was this patch tested?

Added a new test that creates DataFrames with unicode column names, both inferred from a pandas DataFrame and passed explicitly as the schema.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #19738 from BryanCutler/arrow-createDataFrame-followup-unicode-SPARK-20791.
Authored by Bryan Cutler on 2017-11-15 23:35:13 +09:00; committed by hyukjinkwon
Parent commit: dce1610ae3
Commit: 8f0e88df03
2 changed files with 14 additions and 3 deletions

python/pyspark/sql/session.py

@@ -592,6 +592,9 @@ class SparkSession(object):
         if isinstance(schema, basestring):
             schema = _parse_datatype_string(schema)
+        elif isinstance(schema, (list, tuple)):
+            # Must re-encode any unicode strings to be consistent with StructField names
+            schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
 
         try:
             import pandas
@@ -602,7 +605,7 @@
             # If no schema supplied by user then get the names of columns only
             if schema is None:
-                schema = [str(x) for x in data.columns]
+                schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns]
 
             if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
                     and len(data) > 0:
@@ -630,8 +633,6 @@
                 verify_func(obj)
                 return obj,
         else:
-            if isinstance(schema, (list, tuple)):
-                schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
             prepare = lambda obj: obj
 
         if isinstance(data, RDD):

python/pyspark/sql/tests.py

@@ -3225,6 +3225,16 @@ class ArrowTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
         self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
 
+    def test_createDataFrame_column_name_encoding(self):
+        import pandas as pd
+        pdf = pd.DataFrame({u'a': [1]})
+        columns = self.spark.createDataFrame(pdf).columns
+        self.assertTrue(isinstance(columns[0], str))
+        self.assertEquals(columns[0], 'a')
+        columns = self.spark.createDataFrame(pdf, [u'b']).columns
+        self.assertTrue(isinstance(columns[0], str))
+        self.assertEquals(columns[0], 'b')
+
     def test_createDataFrame_with_single_data_type(self):
         import pandas as pd
         with QuietTest(self.sc):
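For context, a usage sketch (not part of the commit) of the behavior the new test exercises, assuming a local SparkSession with pandas and pyarrow installed; the master setting and the column names are arbitrary choices for the example:

```python
from pyspark.sql import SparkSession
import pandas as pd

# Assumed local session for illustration; the test suite uses its own fixture.
spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({u'a': [1]})

# Column names inferred from the pandas DataFrame: a unicode name comes back
# as a plain str after the re-encoding in createDataFrame.
print(spark.createDataFrame(pdf).columns)          # ['a']

# The same holds when unicode names are passed explicitly as the schema.
print(spark.createDataFrame(pdf, [u'b']).columns)  # ['b']

spark.stop()
```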