# [SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas

## What changes were proposed in this pull request?

This change uses Arrow to optimize the creation of a Spark DataFrame from a Pandas DataFrame. The input pandas.DataFrame is sliced according to the default parallelism, converted to Arrow record batches, and sent to the JVM to be parallelized. The optimization is enabled with the existing conf "spark.sql.execution.arrow.enabled" and is disabled by default.
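
A minimal usage sketch of the new path (assumes a running `SparkSession` named `spark` and that pyarrow is installed; the column names are made up):

```python
import pandas as pd

spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({"id": range(1000),
                    "value": [float(i) * 0.1 for i in range(1000)]})
# pdf is sliced by default parallelism, converted to Arrow batches, and parallelized
df = spark.createDataFrame(pdf)
df.show(3)
```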

## How was this patch tested?

Added a new unit test that creates a DataFrame with and without the optimization enabled, then compares the results.
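
A rough sketch of that toggle pattern, mirroring the `test_createDataFrame_toggle` test added below (assumes a `SparkSession` named `spark` and a pandas DataFrame `pdf`):

```python
spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
    df_no_arrow = spark.createDataFrame(pdf)      # plain Python conversion path
finally:
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
df_arrow = spark.createDataFrame(pdf)             # Arrow-optimized path
assert df_no_arrow.collect() == df_arrow.collect()
```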

Author: Bryan Cutler <cutlerb@gmail.com>
Author: Takuya UESHIN <ueshin@databricks.com>

Closes #19459 from BryanCutler/arrow-createDataFrame-from_pandas-SPARK-20791.
Authored by Bryan Cutler on 2017-11-13 13:16:01 +09:00, committed by hyukjinkwon
commit 209b9361ac, parent 3d90b2cb38
8 changed files with 254 additions and 43 deletions

python/pyspark/context.py

@@ -475,24 +475,30 @@ class SparkContext(object):
return xrange(getStart(split), getStart(split + 1), step)
return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
try:
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
serializer.dump_stream(c, tempFile)
jrdd = self._serialize_to_jvm(c, numSlices, serializer)
return RDD(jrdd, self, serializer)
def _serialize_to_jvm(self, data, parallelism, serializer):
"""
Calling the Java parallelize() method with an ArrayList is too slow,
because it sends O(n) Py4J commands. As an alternative, serialized
objects are written to a file and loaded through textFile().
"""
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
try:
serializer.dump_stream(data, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
return readRDDFromFile(self._jsc, tempFile.name, parallelism)
finally:
# readRDDFromFile eagerly reads the file so we can delete right after.
os.unlink(tempFile.name)
return RDD(jrdd, self, serializer)
def pickleFile(self, name, minPartitions=None):
"""

python/pyspark/java_gateway.py

@@ -121,6 +121,7 @@ def launch_gateway(conf=None):
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")

python/pyspark/serializers.py

@@ -214,6 +214,13 @@ class ArrowSerializer(FramedSerializer):
def _create_batch(series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
from pyspark.sql.types import _check_series_convert_timestamps_internal
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
@@ -229,7 +236,8 @@ def _create_batch(series):
# NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
return _check_series_convert_timestamps_internal(s.fillna(0))\
.values.astype('datetime64[us]', copy=False)
elif t == pa.date32():
# NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
elif t is not None and t == pa.date32():
# TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
return s.dt.date
elif t is None or s.dtype == t.to_pandas_dtype():
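
The coercion `_create_batch` performs can be reproduced with pyarrow directly; a small illustrative sketch (the column name and values are made up, and the null masking mirrors what the serializer does for integer columns with nulls):

```python
import pandas as pd
import pyarrow as pa

s = pd.Series([1.0, 2.0, None])   # pandas holds nullable ints as float64 with NaN
arr = pa.Array.from_pandas(s.fillna(0), mask=s.isnull(), type=pa.int32())
batch = pa.RecordBatch.from_arrays([arr], ["2_int_t"])
print(arr)                        # int32 values [1, 2, null]
```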

python/pyspark/sql/session.py

@@ -25,7 +25,7 @@ if sys.version >= '3':
basestring = unicode = str
xrange = range
else:
from itertools import imap as map
from itertools import izip as zip, imap as map
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
@@ -417,12 +417,12 @@ class SparkSession(object):
data = [schema.toInternal(row) for row in data]
return self._sc.parallelize(data), schema
def _get_numpy_record_dtypes(self, rec):
def _get_numpy_record_dtype(self, rec):
"""
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
the dtypes of records so they can be properly loaded into Spark.
:param rec: a numpy record to check dtypes
:return corrected dtypes for a numpy.record or None if no correction needed
the dtypes of fields in a record so they can be properly loaded into Spark.
:param rec: a numpy record to check field dtypes
:return corrected dtype for a numpy.record or None if no correction needed
"""
import numpy as np
cur_dtypes = rec.dtype
@@ -438,28 +438,70 @@ class SparkSession(object):
curr_type = 'datetime64[us]'
has_rec_fix = True
record_type_list.append((str(col_names[i]), curr_type))
return record_type_list if has_rec_fix else None
return np.dtype(record_type_list) if has_rec_fix else None
def _convert_from_pandas(self, pdf, schema):
def _convert_from_pandas(self, pdf):
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
:return tuple of list of records and schema
:return list of records
"""
# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) for x in pdf.columns]
# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)
# Check if any columns need to be fixed for Spark to infer properly
if len(np_records) > 0:
record_type_list = self._get_numpy_record_dtypes(np_records[0])
if record_type_list is not None:
return [r.astype(record_type_list).tolist() for r in np_records], schema
record_dtype = self._get_numpy_record_dtype(np_records[0])
if record_dtype is not None:
return [r.astype(record_dtype).tolist() for r in np_records]
# Convert list of numpy records to python lists
return [r.tolist() for r in np_records], schema
return [r.tolist() for r in np_records]
def _create_from_pandas_with_arrow(self, pdf, schema):
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
elif isinstance(schema, DataType):
raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
else:
# Any timestamps must be coerced to be compatible with Spark
arrow_types = [to_arrow_type(TimestampType())
if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
for t in pdf.dtypes]
# Slice the DataFrame to be batched
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
# Create Arrow record batches
batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
for pdf_slice in pdf_slices]
# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
if isinstance(schema, (list, tuple)):
struct = from_arrow_schema(batches[0].schema)
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct
# Create the Spark DataFrame directly from the Arrow data and schema
jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
jrdd, schema.json(), self._wrapped._jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
return df
@since(2.0)
@ignore_unicode_prefix
@@ -557,7 +599,19 @@ class SparkSession(object):
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
data, schema = self._convert_from_pandas(data, schema)
# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) for x in data.columns]
if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
and len(data) > 0:
try:
return self._create_from_pandas_with_arrow(data, schema)
except Exception as e:
warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
# Fall back to creating the DataFrame without Arrow if an exception is raised
data = self._convert_from_pandas(data)
if isinstance(schema, StructType):
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
@@ -576,7 +630,7 @@ class SparkSession(object):
verify_func(obj)
return obj,
else:
if isinstance(schema, list):
if isinstance(schema, (list, tuple)):
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
prepare = lambda obj: obj
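
The slicing in `_create_from_pandas_with_arrow` above uses `-(-len(pdf) // parallelism)` as a ceiling division, so every row lands in one of at most `defaultParallelism` chunks. A small standalone sketch of that arithmetic (the parallelism of 3 is just an example):

```python
import pandas as pd

pdf = pd.DataFrame({"a": range(10)})
parallelism = 3
step = -(-len(pdf) // parallelism)        # ceil(10 / 3) == 4 without importing math
slices = [pdf[start:start + step] for start in range(0, len(pdf), step)]
assert [len(s) for s in slices] == [4, 4, 2]
```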

python/pyspark/sql/tests.py

@@ -3127,9 +3127,9 @@ class ArrowTests(ReusedSQLTestCase):
StructField("5_double_t", DoubleType(), True),
StructField("6_date_t", DateType(), True),
StructField("7_timestamp_t", TimestampType(), True)])
cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
(u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
@classmethod
def tearDownClass(cls):
@@ -3145,6 +3145,17 @@ class ArrowTests(ReusedSQLTestCase):
("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
def create_pandas_data_frame(self):
import pandas as pd
import numpy as np
data_dict = {}
for j, name in enumerate(self.schema.names):
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
# need to convert these to numpy types first
data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)
def test_unsupported_datatype(self):
schema = StructType([StructField("decimal", DecimalType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -3161,21 +3172,15 @@ class ArrowTests(ReusedSQLTestCase):
def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
pdf = df.toPandas()
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)
def test_pandas_round_trip(self):
import pandas as pd
import numpy as np
data_dict = {}
for j, name in enumerate(self.schema.names):
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
# need to convert these to numpy types first
data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
pdf = pd.DataFrame(data=data_dict)
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(self.data, schema=self.schema)
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)
@@ -3187,6 +3192,62 @@ class ArrowTests(ReusedSQLTestCase):
self.assertEqual(pdf.columns[0], "i")
self.assertTrue(pdf.empty)
def test_createDataFrame_toggle(self):
pdf = self.create_pandas_data_frame()
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
df_no_arrow = self.spark.createDataFrame(pdf)
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
df_arrow = self.spark.createDataFrame(pdf)
self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(pdf, schema=self.schema)
self.assertEquals(self.schema, df.schema)
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)
def test_createDataFrame_with_incorrect_schema(self):
pdf = self.create_pandas_data_frame()
wrong_schema = StructType(list(reversed(self.schema)))
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
self.spark.createDataFrame(pdf, schema=wrong_schema)
def test_createDataFrame_with_names(self):
pdf = self.create_pandas_data_frame()
# Test that schema as a list of column names gets applied
df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
# Test that schema as tuple of column names gets applied
df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
def test_createDataFrame_with_single_data_type(self):
import pandas as pd
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
def test_createDataFrame_does_not_modify_input(self):
# Some series get converted for Spark to consume, this makes sure input is unchanged
pdf = self.create_pandas_data_frame()
# Use a nanosecond value to make sure it is not truncated
pdf.ix[0, '7_timestamp_t'] = 1
# Integers with nulls will get NaNs filled with 0 and will be casted
pdf.ix[1, '2_int_t'] = None
pdf_copy = pdf.copy(deep=True)
self.spark.createDataFrame(pdf, schema=self.schema)
self.assertTrue(pdf.equals(pdf_copy))
def test_schema_conversion_roundtrip(self):
from pyspark.sql.types import from_arrow_schema, to_arrow_schema
arrow_schema = to_arrow_schema(self.schema)
schema_rt = from_arrow_schema(arrow_schema)
self.assertEquals(self.schema, schema_rt)
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class VectorizedUDFTests(ReusedSQLTestCase):

python/pyspark/sql/types.py

@@ -1629,6 +1629,55 @@ def to_arrow_type(dt):
return arrow_type
def to_arrow_schema(schema):
""" Convert a schema from Spark to Arrow
"""
import pyarrow as pa
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
for field in schema]
return pa.schema(fields)
def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
# TODO: newer pyarrow has is_boolean(at) functions that would be better to check type
import pyarrow as pa
if at == pa.bool_():
spark_type = BooleanType()
elif at == pa.int8():
spark_type = ByteType()
elif at == pa.int16():
spark_type = ShortType()
elif at == pa.int32():
spark_type = IntegerType()
elif at == pa.int64():
spark_type = LongType()
elif at == pa.float32():
spark_type = FloatType()
elif at == pa.float64():
spark_type = DoubleType()
elif type(at) == pa.DecimalType:
spark_type = DecimalType(precision=at.precision, scale=at.scale)
elif at == pa.string():
spark_type = StringType()
elif at == pa.date32():
spark_type = DateType()
elif type(at) == pa.TimestampType:
spark_type = TimestampType()
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type
def from_arrow_schema(arrow_schema):
""" Convert schema from Arrow to Spark.
"""
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in arrow_schema])
def _check_dataframe_localize_timestamps(pdf):
"""
Convert timezone aware timestamps to timezone-naive in local time
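
A round-trip sketch of the converters added above, mirroring `test_schema_conversion_roundtrip` (assumes pyarrow is installed; the field names are borrowed from the test schema):

```python
from pyspark.sql.types import (StructType, StructField, StringType, TimestampType,
                               to_arrow_schema, from_arrow_schema)

spark_schema = StructType([StructField("1_str_t", StringType(), True),
                           StructField("7_timestamp_t", TimestampType(), True)])
arrow_schema = to_arrow_schema(spark_schema)       # pyarrow.Schema
assert from_arrow_schema(arrow_schema) == spark_schema
```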

sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala

@@ -17,9 +17,12 @@
package org.apache.spark.sql.api.python
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.DataType
private[sql] object PythonSQLUtils {
@@ -29,4 +32,19 @@ private[sql] object PythonSQLUtils {
def listBuiltinFunctionInfos(): Array[ExpressionInfo] = {
FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray
}
/**
* Python Callable function to convert ArrowPayloads into a [[DataFrame]].
*
* @param payloadRDD A JavaRDD of ArrowPayloads.
* @param schemaString JSON Formatted Schema for ArrowPayloads.
* @param sqlContext The active [[SQLContext]].
* @return The converted [[DataFrame]].
*/
def arrowPayloadToDataFrame(
payloadRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
}
}

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

@@ -29,6 +29,8 @@ import org.apache.arrow.vector.schema.ArrowRecordBatch
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._
@@ -204,4 +206,16 @@ private[sql] object ArrowConverters {
reader.close()
}
}
private[sql] def toDataFrame(
payloadRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
val rdd = payloadRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
}
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
sqlContext.internalCreateDataFrame(rdd, schema)
}
}