[SPARK-23290][SQL][PYTHON] Use datetime.date for date type when converting Spark DataFrame to Pandas DataFrame.
## What changes were proposed in this pull request? In #18664, there was a change in how `DateType` values are returned to users ([line 1968 in dataframe.py](https://github.com/apache/spark/pull/18664/files#diff-6fc344560230bf0ef711bb9b5573f1faR1968)). This can cause client code that works in Spark 2.2 to fail. See [SPARK-23290](https://issues.apache.org/jira/browse/SPARK-23290?focusedCommentId=16350917&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16350917) for an example. This PR changes the conversion to use `datetime.date` for date type values, as Spark 2.2 does. ## How was this patch tested? Modified the affected tests to fit the new behavior, and ran the existing tests. Author: Takuya UESHIN <ueshin@databricks.com> Closes #20506 from ueshin/issues/SPARK-23290.
This commit is contained in:
parent
f3f1e14bb7
commit
a24c03138a
|
@ -267,12 +267,15 @@ class ArrowStreamPandasSerializer(Serializer):
|
|||
"""
|
||||
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
|
||||
"""
|
||||
from pyspark.sql.types import _check_dataframe_localize_timestamps
|
||||
from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
|
||||
_check_dataframe_localize_timestamps
|
||||
import pyarrow as pa
|
||||
reader = pa.open_stream(stream)
|
||||
schema = from_arrow_schema(reader.schema)
|
||||
for batch in reader:
|
||||
# NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
|
||||
pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone)
|
||||
pdf = batch.to_pandas()
|
||||
pdf = _check_dataframe_convert_date(pdf, schema)
|
||||
pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
|
||||
yield [c for _, c in pdf.iteritems()]
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -1923,7 +1923,8 @@ class DataFrame(object):
|
|||
|
||||
if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
|
||||
try:
|
||||
from pyspark.sql.types import _check_dataframe_localize_timestamps
|
||||
from pyspark.sql.types import _check_dataframe_convert_date, \
|
||||
_check_dataframe_localize_timestamps
|
||||
from pyspark.sql.utils import require_minimum_pyarrow_version
|
||||
import pyarrow
|
||||
require_minimum_pyarrow_version()
|
||||
|
@ -1931,6 +1932,7 @@ class DataFrame(object):
|
|||
if tables:
|
||||
table = pyarrow.concat_tables(tables)
|
||||
pdf = table.to_pandas()
|
||||
pdf = _check_dataframe_convert_date(pdf, self.schema)
|
||||
return _check_dataframe_localize_timestamps(pdf, timezone)
|
||||
else:
|
||||
return pd.DataFrame.from_records([], columns=self.columns)
|
||||
|
@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt):
|
|||
"""
|
||||
When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
|
||||
This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
|
||||
NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
|
||||
"""
|
||||
import numpy as np
|
||||
if type(dt) == ByteType:
|
||||
|
@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt):
|
|||
return np.int32
|
||||
elif type(dt) == FloatType:
|
||||
return np.float32
|
||||
elif type(dt) == DateType:
|
||||
return 'datetime64[ns]'
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
|
@ -2816,7 +2816,7 @@ class SQLTests(ReusedSQLTestCase):
|
|||
self.assertEquals(types[1], np.object)
|
||||
self.assertEquals(types[2], np.bool)
|
||||
self.assertEquals(types[3], np.float32)
|
||||
self.assertEquals(types[4], 'datetime64[ns]')
|
||||
self.assertEquals(types[4], np.object) # datetime.date
|
||||
self.assertEquals(types[5], 'datetime64[ns]')
|
||||
|
||||
@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
|
||||
|
@ -3388,7 +3388,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
ReusedSQLTestCase.setUpClass()
|
||||
|
||||
|
@ -3410,11 +3410,11 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
StructField("7_date_t", DateType(), True),
|
||||
StructField("8_timestamp_t", TimestampType(), True)])
|
||||
cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
|
||||
datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
|
||||
date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
|
||||
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
|
||||
datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
|
||||
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
|
||||
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
|
||||
datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
|
||||
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
@ -3461,7 +3461,9 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
def test_toPandas_arrow_toggle(self):
|
||||
df = self.spark.createDataFrame(self.data, schema=self.schema)
|
||||
pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
|
||||
self.assertPandasEqual(pdf_arrow, pdf)
|
||||
expected = self.create_pandas_data_frame()
|
||||
self.assertPandasEqual(expected, pdf)
|
||||
self.assertPandasEqual(expected, pdf_arrow)
|
||||
|
||||
def test_toPandas_respect_session_timezone(self):
|
||||
df = self.spark.createDataFrame(self.data, schema=self.schema)
|
||||
|
@ -4062,18 +4064,42 @@ class ScalarPandasUDF(ReusedSQLTestCase):
|
|||
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
|
||||
df.select(f(col('map'))).collect()
|
||||
|
||||
def test_vectorized_udf_null_date(self):
|
||||
def test_vectorized_udf_dates(self):
|
||||
from pyspark.sql.functions import pandas_udf, col
|
||||
from datetime import date
|
||||
schema = StructType().add("date", DateType())
|
||||
data = [(date(1969, 1, 1),),
|
||||
(date(2012, 2, 2),),
|
||||
(None,),
|
||||
(date(2100, 4, 4),)]
|
||||
schema = StructType().add("idx", LongType()).add("date", DateType())
|
||||
data = [(0, date(1969, 1, 1),),
|
||||
(1, date(2012, 2, 2),),
|
||||
(2, None,),
|
||||
(3, date(2100, 4, 4),)]
|
||||
df = self.spark.createDataFrame(data, schema=schema)
|
||||
date_f = pandas_udf(lambda t: t, returnType=DateType())
|
||||
res = df.select(date_f(col("date")))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
|
||||
date_copy = pandas_udf(lambda t: t, returnType=DateType())
|
||||
df = df.withColumn("date_copy", date_copy(col("date")))
|
||||
|
||||
@pandas_udf(returnType=StringType())
|
||||
def check_data(idx, date, date_copy):
|
||||
import pandas as pd
|
||||
msgs = []
|
||||
is_equal = date.isnull()
|
||||
for i in range(len(idx)):
|
||||
if (is_equal[i] and data[idx[i]][1] is None) or \
|
||||
date[i] == data[idx[i]][1]:
|
||||
msgs.append(None)
|
||||
else:
|
||||
msgs.append(
|
||||
"date values are not equal (date='%s': data[%d][1]='%s')"
|
||||
% (date[i], idx[i], data[idx[i]][1]))
|
||||
return pd.Series(msgs)
|
||||
|
||||
result = df.withColumn("check_data",
|
||||
check_data(col("idx"), col("date"), col("date_copy"))).collect()
|
||||
|
||||
self.assertEquals(len(data), len(result))
|
||||
for i in range(len(result)):
|
||||
self.assertEquals(data[i][1], result[i][1]) # "date" col
|
||||
self.assertEquals(data[i][1], result[i][2]) # "date_copy" col
|
||||
self.assertIsNone(result[i][3]) # "check_data" col
|
||||
|
||||
def test_vectorized_udf_timestamps(self):
|
||||
from pyspark.sql.functions import pandas_udf, col
|
||||
|
@ -4114,6 +4140,7 @@ class ScalarPandasUDF(ReusedSQLTestCase):
|
|||
self.assertEquals(len(data), len(result))
|
||||
for i in range(len(result)):
|
||||
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
|
||||
self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col
|
||||
self.assertIsNone(result[i][3]) # "check_data" col
|
||||
|
||||
def test_vectorized_udf_return_timestamp_tz(self):
|
||||
|
|
|
@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema):
|
|||
for field in arrow_schema])
|
||||
|
||||
|
||||
def _check_dataframe_convert_date(pdf, schema):
    """ Replace date columns with columns of datetime.date values.

    A pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we
    should use datetime.date so the result matches the behavior when Arrow optimization is
    disabled.

    The given DataFrame is modified in place and also returned.

    :param pdf: pandas.DataFrame
    :param schema: a Spark schema of the pandas.DataFrame
    """
    # Pick out the names of the columns whose Spark type is exactly DateType.
    date_column_names = [f.name for f in schema if type(f.dataType) == DateType]
    for column_name in date_column_names:
        date_values = pdf[column_name].dt.date
        pdf[column_name] = date_values
    return pdf
|
||||
|
||||
|
||||
def _check_dataframe_localize_timestamps(pdf, timezone):
|
||||
"""
|
||||
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
|
||||
|
|
Loading…
Reference in a new issue