[SPARK-20791][PYTHON][FOLLOWUP] Check for unicode column names in createDataFrame with Arrow
## What changes were proposed in this pull request?

If a schema is passed as a list of unicode strings for column names, they should be re-encoded to 'utf-8' to be consistent. This is similar to #13097, but for the creation of a DataFrame using Arrow.

## How was this patch tested?

Added a new test that uses unicode names for the schema.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #19738 from BryanCutler/arrow-createDataFrame-followup-unicode-SPARK-20791.
This commit is contained in:
parent
dce1610ae3
commit
8f0e88df03
|
@ -592,6 +592,9 @@ class SparkSession(object):
|
|||
|
||||
if isinstance(schema, basestring):
|
||||
schema = _parse_datatype_string(schema)
|
||||
elif isinstance(schema, (list, tuple)):
|
||||
# Must re-encode any unicode strings to be consistent with StructField names
|
||||
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
|
||||
|
||||
try:
|
||||
import pandas
|
||||
|
@ -602,7 +605,7 @@ class SparkSession(object):
|
|||
|
||||
# If no schema supplied by user then get the names of columns only
|
||||
if schema is None:
|
||||
schema = [str(x) for x in data.columns]
|
||||
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns]
|
||||
|
||||
if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
|
||||
and len(data) > 0:
|
||||
|
@ -630,8 +633,6 @@ class SparkSession(object):
|
|||
verify_func(obj)
|
||||
return obj,
|
||||
else:
|
||||
if isinstance(schema, (list, tuple)):
|
||||
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
|
||||
prepare = lambda obj: obj
|
||||
|
||||
if isinstance(data, RDD):
|
||||
|
|
|
@ -3225,6 +3225,16 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
|
||||
self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
|
||||
|
||||
def test_createDataFrame_column_name_encoding(self):
|
||||
import pandas as pd
|
||||
pdf = pd.DataFrame({u'a': [1]})
|
||||
columns = self.spark.createDataFrame(pdf).columns
|
||||
self.assertTrue(isinstance(columns[0], str))
|
||||
self.assertEquals(columns[0], 'a')
|
||||
columns = self.spark.createDataFrame(pdf, [u'b']).columns
|
||||
self.assertTrue(isinstance(columns[0], str))
|
||||
self.assertEquals(columns[0], 'b')
|
||||
|
||||
def test_createDataFrame_with_single_data_type(self):
|
||||
import pandas as pd
|
||||
with QuietTest(self.sc):
|
||||
|
|
Loading…
Reference in a new issue