diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 88ac4134a0..7b81a0be84 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -37,6 +37,7 @@ from pyspark.sql.types import _parse_datatype_json_string
 from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter
 from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import IntegralType
 from pyspark.sql.types import *
 
 __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
@@ -1891,14 +1892,20 @@ class DataFrame(object):
                       "if using spark.sql.execution.arrow.enable=true"
                 raise ImportError("%s\n%s" % (e.message, msg))
         else:
+            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
             dtype = {}
             for field in self.schema:
                 pandas_type = _to_corrected_pandas_type(field.dataType)
-                if pandas_type is not None:
+                # SPARK-21766: if an integer field is nullable and has null values, it can be
+                # inferred by pandas as float column. Once we convert the column with NaN back
+                # to integer type e.g., np.int16, we will hit exception. So we use the inferred
+                # float type, not the corrected type from the schema in this case.
+                if pandas_type is not None and \
+                    not(isinstance(field.dataType, IntegralType) and field.nullable and
+                        pdf[field.name].isnull().any()):
                     dtype[field.name] = pandas_type
 
-            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
-
             for f, t in dtype.items():
                 pdf[f] = pdf[f].astype(t, copy=False)
             return pdf
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ab76c48e00..3db8bee203 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2564,6 +2564,18 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEquals(types[2], np.bool)
         self.assertEquals(types[3], np.float32)
 
+    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    def test_to_pandas_avoid_astype(self):
+        import numpy as np
+        schema = StructType().add("a", IntegerType()).add("b", StringType())\
+            .add("c", IntegerType())
+        data = [(1, "foo", 16777220), (None, "bar", None)]
+        df = self.spark.createDataFrame(data, schema)
+        types = df.toPandas().dtypes
+        self.assertEquals(types[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
+        self.assertEquals(types[1], np.object)
+        self.assertEquals(types[2], np.float64)
+
     def test_create_dataframe_from_array_of_long(self):
         import array
         data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
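
Note (not part of the patch): a minimal standalone sketch of the pandas behavior that
motivates the SPARK-21766 guard above, assuming only pandas and numpy are available:

    import numpy as np
    import pandas as pd

    # pandas has no NA value for plain integer dtypes, so a column containing
    # None is inferred as float64 with None mapped to NaN -- the same inference
    # from_records() performs on a nullable int column inside toPandas().
    pdf = pd.DataFrame.from_records([(1,), (None,)], columns=["a"])
    print(pdf["a"].dtype)  # float64

    # Casting the NaN-bearing float column back to an integer dtype raises a
    # ValueError; this is the exception the patched toPandas() avoids by keeping
    # the inferred float dtype for nullable integral fields that contain nulls.
    try:
        pdf["a"].astype(np.int16, copy=False)
    except ValueError as e:
        print("astype failed: %s" % e)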