[SPARK-23009][PYTHON] Fix for non-str col names to createDataFrame from Pandas
## What changes were proposed in this pull request?

This fixes the case when calling `SparkSession.createDataFrame` using a Pandas DataFrame that has non-str column labels.

The column name conversion logic to handle non-string labels, or unicode labels in Python 2, is:

```
if column is not any type of string:
    name = str(column)
else if column is unicode in Python 2:
    name = column.encode('utf-8')
```

## How was this patch tested?

Added a new test with a Pandas DataFrame that has int column labels.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #20210 from BryanCutler/python-createDataFrame-int-col-error-SPARK-23009.
This commit is contained in:
parent
7bcc266681
commit
e599837248
|
@ -648,7 +648,9 @@ class SparkSession(object):
|
||||||
|
|
||||||
# If no schema supplied by user then get the names of columns only
|
# If no schema supplied by user then get the names of columns only
|
||||||
if schema is None:
|
if schema is None:
|
||||||
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns]
|
schema = [str(x) if not isinstance(x, basestring) else
|
||||||
|
(x.encode('utf-8') if not isinstance(x, str) else x)
|
||||||
|
for x in data.columns]
|
||||||
|
|
||||||
if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
|
if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
|
||||||
and len(data) > 0:
|
and len(data) > 0:
|
||||||
|
|
|
@ -3532,6 +3532,15 @@ class ArrowTests(ReusedSQLTestCase):
|
||||||
self.assertTrue(expected[r][e] == result_arrow[r][e] and
|
self.assertTrue(expected[r][e] == result_arrow[r][e] and
|
||||||
result[r][e] == result_arrow[r][e])
|
result[r][e] == result_arrow[r][e])
|
||||||
|
|
||||||
|
def test_createDataFrame_with_int_col_names(self):
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
pdf = pd.DataFrame(np.random.rand(4, 2))
|
||||||
|
df, df_arrow = self._createDataFrame_toggle(pdf)
|
||||||
|
pdf_col_names = [str(c) for c in pdf.columns]
|
||||||
|
self.assertEqual(pdf_col_names, df.columns)
|
||||||
|
self.assertEqual(pdf_col_names, df_arrow.columns)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
|
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
|
||||||
class PandasUDFTests(ReusedSQLTestCase):
|
class PandasUDFTests(ReusedSQLTestCase):
|
||||||
|
|
Loading…
Reference in a new issue