[SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas
## What changes were proposed in this pull request?

This change uses Arrow to optimize the creation of a Spark DataFrame from a Pandas DataFrame. The input DataFrame is sliced according to the default parallelism. The optimization is enabled with the existing conf "spark.sql.execution.arrow.enabled" and is disabled by default.

## How was this patch tested?

Added a new unit test that creates a DataFrame with and without the optimization enabled, then compares the results.

Author: Bryan Cutler <cutlerb@gmail.com>
Author: Takuya UESHIN <ueshin@databricks.com>

Closes #19459 from BryanCutler/arrow-createDataFrame-from_pandas-SPARK-20791.
parent 3d90b2cb38
commit 209b9361ac
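Before the diff itself, here is a minimal usage sketch of what this optimization enables. It is illustrative only (the session setup and column names are not part of the patch): with the existing conf turned on, `createDataFrame` takes the Arrow path for a pandas.DataFrame and falls back to the old row-by-row conversion with a warning if the Arrow path fails.

```python
# Illustrative sketch only: assumes a local SparkSession and that pandas and pyarrow are installed.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").appName("arrow-demo").getOrCreate()

# Enable the existing conf; it is "false" by default.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

# With the conf enabled this goes through the new _create_from_pandas_with_arrow;
# if that raises, createDataFrame warns and falls back to the non-Arrow path.
df = spark.createDataFrame(pdf)
df.show()
```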
@@ -475,24 +475,30 @@ class SparkContext(object):
                 return xrange(getStart(split), getStart(split + 1), step)
 
             return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands. As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
+
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(c):
+            c = list(c)    # Make it a list so we can compute its length
+        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+        serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
+        return RDD(jrdd, self, serializer)
+
+    def _serialize_to_jvm(self, data, parallelism, serializer):
+        """
+        Calling the Java parallelize() method with an ArrayList is too slow,
+        because it sends O(n) Py4J commands. As an alternative, serialized
+        objects are written to a file and loaded through textFile().
+        """
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         try:
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(c):
-                c = list(c)    # Make it a list so we can compute its length
-            batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
-            serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
-            serializer.dump_stream(c, tempFile)
+            serializer.dump_stream(data, tempFile)
             tempFile.close()
             readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-            jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+            return readRDDFromFile(self._jsc, tempFile.name, parallelism)
         finally:
             # readRDDFromFile eagerily reads the file so we can delete right after.
             os.unlink(tempFile.name)
-        return RDD(jrdd, self, serializer)
 
     def pickleFile(self, name, minPartitions=None):
         """
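The refactor above pulls the temp-file hand-off out of `parallelize` so that any serializer can push a local collection to the JVM; the Arrow path added later in this patch reuses it with an `ArrowSerializer`. A hypothetical sketch of that pattern (not part of the patch itself, just the helper used the way `parallelize` now uses it internally):

```python
# Hypothetical sketch: any serializer with dump_stream() can now ship a local
# collection to the JVM through the same temp-file path that parallelize() uses.
from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.serializers import BatchedSerializer, PickleSerializer

sc = SparkContext.getOrCreate()
ser = BatchedSerializer(PickleSerializer(), 2)

# Equivalent to what parallelize() now does after computing its batch size.
jrdd = sc._serialize_to_jvm(range(10), parallelism=4, serializer=ser)
rdd = RDD(jrdd, sc, ser)
print(rdd.collect())  # expected: [0, 1, ..., 9]
```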
@@ -121,6 +121,7 @@ def launch_gateway(conf=None):
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     # TODO(davies): move into sql
     java_import(gateway.jvm, "org.apache.spark.sql.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
     java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
     java_import(gateway.jvm, "scala.Tuple2")
 
@@ -214,6 +214,13 @@ class ArrowSerializer(FramedSerializer):
 
 
 def _create_batch(series):
+    """
+    Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
+
+    :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
+    :return: Arrow RecordBatch
+    """
+    from pyspark.sql.types import _check_series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -229,7 +236,8 @@ def _create_batch(series):
             # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
             return _check_series_convert_timestamps_internal(s.fillna(0))\
                 .values.astype('datetime64[us]', copy=False)
-        elif t == pa.date32():
+        # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
+        elif t is not None and t == pa.date32():
             # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
             return s.dt.date
         elif t is None or s.dtype == t.to_pandas_dtype():
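A short note on the reordered condition above: the added `t is not None` guard makes the check short-circuit before the pyarrow comparison runs, which is what the NOTE about pre-0.7.1 Arrow refers to. A tiny illustration with hypothetical values:

```python
import pyarrow as pa

t = None  # no target Arrow type was supplied for this column

# The guard short-circuits, so `None == pa.date32()` (which the NOTE says older
# pyarrow could not handle) is never evaluated when there is no type to compare.
if t is not None and t == pa.date32():
    print("date column")
else:
    print("not a date column, or no type given")
```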
@@ -25,7 +25,7 @@ if sys.version >= '3':
     basestring = unicode = str
     xrange = range
 else:
-    from itertools import imap as map
+    from itertools import izip as zip, imap as map
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
@@ -417,12 +417,12 @@ class SparkSession(object):
             data = [schema.toInternal(row) for row in data]
         return self._sc.parallelize(data), schema
 
-    def _get_numpy_record_dtypes(self, rec):
+    def _get_numpy_record_dtype(self, rec):
         """
         Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
-        the dtypes of records so they can be properly loaded into Spark.
-        :param rec: a numpy record to check dtypes
-        :return corrected dtypes for a numpy.record or None if no correction needed
+        the dtypes of fields in a record so they can be properly loaded into Spark.
+        :param rec: a numpy record to check field dtypes
+        :return corrected dtype for a numpy.record or None if no correction needed
         """
         import numpy as np
         cur_dtypes = rec.dtype
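For readers unfamiliar with `to_records()`: the renamed helper above inspects one numpy record and, when needed, returns a corrected `np.dtype`, for example narrowing nanosecond timestamp fields to `datetime64[us]` (see the `curr_type = 'datetime64[us]'` branch in the next hunk). A standalone illustration of that kind of fix-up, with made-up column names and not the exact implementation:

```python
# Hypothetical illustration of the dtype correction _get_numpy_record_dtype performs.
import numpy as np
import pandas as pd

pdf = pd.DataFrame({"id": [1, 2], "ts": pd.to_datetime(["2012-02-02", "2100-03-03"])})
rec = pdf.to_records(index=False)[0]
print(rec.dtype)  # e.g. [('id', '<i8'), ('ts', '<M8[ns]')]

# Build a corrected record dtype, narrowing datetime64[ns] fields to microseconds
# so the values can be loaded into Spark's timestamp type.
fixed = np.dtype([(name, 'datetime64[us]' if rec.dtype[name].kind == 'M' else rec.dtype[name])
                  for name in rec.dtype.names])
print(rec.astype(fixed).tolist())
```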
@@ -438,28 +438,70 @@ class SparkSession(object):
                 curr_type = 'datetime64[us]'
                 has_rec_fix = True
             record_type_list.append((str(col_names[i]), curr_type))
-        return record_type_list if has_rec_fix else None
+        return np.dtype(record_type_list) if has_rec_fix else None
 
-    def _convert_from_pandas(self, pdf, schema):
+    def _convert_from_pandas(self, pdf):
         """
         Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
-        :return tuple of list of records and schema
+        :return list of records
         """
-        # If no schema supplied by user then get the names of columns only
-        if schema is None:
-            schema = [str(x) for x in pdf.columns]
-
         # Convert pandas.DataFrame to list of numpy records
         np_records = pdf.to_records(index=False)
 
         # Check if any columns need to be fixed for Spark to infer properly
         if len(np_records) > 0:
-            record_type_list = self._get_numpy_record_dtypes(np_records[0])
-            if record_type_list is not None:
-                return [r.astype(record_type_list).tolist() for r in np_records], schema
+            record_dtype = self._get_numpy_record_dtype(np_records[0])
+            if record_dtype is not None:
+                return [r.astype(record_dtype).tolist() for r in np_records]
 
         # Convert list of numpy records to python lists
-        return [r.tolist() for r in np_records], schema
+        return [r.tolist() for r in np_records]
+
+    def _create_from_pandas_with_arrow(self, pdf, schema):
+        """
+        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
+        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
+        data types will be used to coerce the data in Pandas to Arrow conversion.
+        """
+        from pyspark.serializers import ArrowSerializer, _create_batch
+        from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+
+        # Determine arrow types to coerce data when creating batches
+        if isinstance(schema, StructType):
+            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+        elif isinstance(schema, DataType):
+            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
+        else:
+            # Any timestamps must be coerced to be compatible with Spark
+            arrow_types = [to_arrow_type(TimestampType())
+                           if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
+                           for t in pdf.dtypes]
+
+        # Slice the DataFrame to be batched
+        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
+        pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
+
+        # Create Arrow record batches
+        batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
+                   for pdf_slice in pdf_slices]
+
+        # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
+        if isinstance(schema, (list, tuple)):
+            struct = from_arrow_schema(batches[0].schema)
+            for i, name in enumerate(schema):
+                struct.fields[i].name = name
+                struct.names[i] = name
+            schema = struct
+
+        # Create the Spark DataFrame directly from the Arrow data and schema
+        jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
+        jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
+            jrdd, schema.json(), self._wrapped._jsqlContext)
+        df = DataFrame(jdf, self._wrapped)
+        df._schema = schema
+        return df
+
     @since(2.0)
     @ignore_unicode_prefix
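The slicing line above, `step = -(-len(pdf) // self.sparkContext.defaultParallelism)`, is ceiling division written with floor division: every slice except possibly the last gets the same number of rows and no rows are dropped. A quick worked example in plain Python (no Spark needed), with a hypothetical helper name:

```python
# Ceiling division via double negation: -(-a // b) == ceil(a / b) for positive ints.
def slice_bounds(n_rows, parallelism):
    step = -(-n_rows // parallelism)  # round up, as in _create_from_pandas_with_arrow
    return [(start, min(start + step, n_rows)) for start in range(0, n_rows, step)]

print(slice_bounds(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)] -> slices of 3, 3, 3, 1 rows
print(slice_bounds(3, 8))   # [(0, 1), (1, 2), (2, 3)] -> never more slices than rows
```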
@@ -557,7 +599,19 @@ class SparkSession(object):
         except Exception:
             has_pandas = False
         if has_pandas and isinstance(data, pandas.DataFrame):
-            data, schema = self._convert_from_pandas(data, schema)
+
+            # If no schema supplied by user then get the names of columns only
+            if schema is None:
+                schema = [str(x) for x in data.columns]
+
+            if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
+                    and len(data) > 0:
+                try:
+                    return self._create_from_pandas_with_arrow(data, schema)
+                except Exception as e:
+                    warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
+                    # Fallback to create DataFrame without arrow if raise some exception
+            data = self._convert_from_pandas(data)
 
         if isinstance(schema, StructType):
             verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
@@ -576,7 +630,7 @@ class SparkSession(object):
                 verify_func(obj)
                 return obj,
         else:
-            if isinstance(schema, list):
+            if isinstance(schema, (list, tuple)):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
             prepare = lambda obj: obj
 
@@ -3127,9 +3127,9 @@ class ArrowTests(ReusedSQLTestCase):
             StructField("5_double_t", DoubleType(), True),
             StructField("6_date_t", DateType(), True),
             StructField("7_timestamp_t", TimestampType(), True)])
-        cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-                    ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
-                    ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3145,6 +3145,17 @@ class ArrowTests(ReusedSQLTestCase):
                ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
 
+    def create_pandas_data_frame(self):
+        import pandas as pd
+        import numpy as np
+        data_dict = {}
+        for j, name in enumerate(self.schema.names):
+            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+        # need to convert these to numpy types first
+        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+        return pd.DataFrame(data=data_dict)
+
     def test_unsupported_datatype(self):
         schema = StructType([StructField("decimal", DecimalType(), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -3161,21 +3172,15 @@ class ArrowTests(ReusedSQLTestCase):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        pdf = df.toPandas()
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        try:
+            pdf = df.toPandas()
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
 
     def test_pandas_round_trip(self):
-        import pandas as pd
-        import numpy as np
-        data_dict = {}
-        for j, name in enumerate(self.schema.names):
-            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
-        # need to convert these to numpy types first
-        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
-        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
-        pdf = pd.DataFrame(data=data_dict)
+        pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
@@ -3187,6 +3192,62 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_createDataFrame_toggle(self):
+        pdf = self.create_pandas_data_frame()
+        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
+        try:
+            df_no_arrow = self.spark.createDataFrame(pdf)
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        df_arrow = self.spark.createDataFrame(pdf)
+        self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+
+    def test_createDataFrame_with_schema(self):
+        pdf = self.create_pandas_data_frame()
+        df = self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertEquals(self.schema, df.schema)
+        pdf_arrow = df.toPandas()
+        self.assertFramesEqual(pdf_arrow, pdf)
+
+    def test_createDataFrame_with_incorrect_schema(self):
+        pdf = self.create_pandas_data_frame()
+        wrong_schema = StructType(list(reversed(self.schema)))
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+                self.spark.createDataFrame(pdf, schema=wrong_schema)
+
+    def test_createDataFrame_with_names(self):
+        pdf = self.create_pandas_data_frame()
+        # Test that schema as a list of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        # Test that schema as tuple of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+
+    def test_createDataFrame_with_single_data_type(self):
+        import pandas as pd
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
+
+    def test_createDataFrame_does_not_modify_input(self):
+        # Some series get converted for Spark to consume, this makes sure input is unchanged
+        pdf = self.create_pandas_data_frame()
+        # Use a nanosecond value to make sure it is not truncated
+        pdf.ix[0, '7_timestamp_t'] = 1
+        # Integers with nulls will get NaNs filled with 0 and will be casted
+        pdf.ix[1, '2_int_t'] = None
+        pdf_copy = pdf.copy(deep=True)
+        self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertTrue(pdf.equals(pdf_copy))
+
+    def test_schema_conversion_roundtrip(self):
+        from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+        arrow_schema = to_arrow_schema(self.schema)
+        schema_rt = from_arrow_schema(arrow_schema)
+        self.assertEquals(self.schema, schema_rt)
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class VectorizedUDFTests(ReusedSQLTestCase):
@@ -1629,6 +1629,55 @@ def to_arrow_type(dt):
     return arrow_type
 
 
+def to_arrow_schema(schema):
+    """ Convert a schema from Spark to Arrow
+    """
+    import pyarrow as pa
+    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+              for field in schema]
+    return pa.schema(fields)
+
+
+def from_arrow_type(at):
+    """ Convert pyarrow type to Spark data type.
+    """
+    # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type
+    import pyarrow as pa
+    if at == pa.bool_():
+        spark_type = BooleanType()
+    elif at == pa.int8():
+        spark_type = ByteType()
+    elif at == pa.int16():
+        spark_type = ShortType()
+    elif at == pa.int32():
+        spark_type = IntegerType()
+    elif at == pa.int64():
+        spark_type = LongType()
+    elif at == pa.float32():
+        spark_type = FloatType()
+    elif at == pa.float64():
+        spark_type = DoubleType()
+    elif type(at) == pa.DecimalType:
+        spark_type = DecimalType(precision=at.precision, scale=at.scale)
+    elif at == pa.string():
+        spark_type = StringType()
+    elif at == pa.date32():
+        spark_type = DateType()
+    elif type(at) == pa.TimestampType:
+        spark_type = TimestampType()
+    else:
+        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
+    return spark_type
+
+
+def from_arrow_schema(arrow_schema):
+    """ Convert schema from Arrow to Spark.
+    """
+    return StructType(
+        [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+         for field in arrow_schema])
+
+
 def _check_dataframe_localize_timestamps(pdf):
     """
     Convert timezone aware timestamps to timezone-naive in local time
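The two schema helpers above are meant to be inverses for the supported types, which is exactly what the new `test_schema_conversion_roundtrip` checks. A short round-trip sketch (the field names here are illustrative, not from the patch):

```python
# Round-trip sketch: Spark StructType -> pyarrow schema -> StructType.
from pyspark.sql.types import (StructType, StructField, LongType, StringType,
                               TimestampType, from_arrow_schema, to_arrow_schema)

schema = StructType([
    StructField("id", LongType(), True),
    StructField("name", StringType(), True),
    StructField("ts", TimestampType(), True),
])

arrow_schema = to_arrow_schema(schema)   # e.g. int64, string, microsecond timestamp
schema_rt = from_arrow_schema(arrow_schema)
assert schema == schema_rt
```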
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.api.python
 
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types.DataType
 
 private[sql] object PythonSQLUtils {
@@ -29,4 +32,19 @@ private[sql] object PythonSQLUtils {
   def listBuiltinFunctionInfos(): Array[ExpressionInfo] = {
     FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray
   }
+
+  /**
+   * Python Callable function to convert ArrowPayloads into a [[DataFrame]].
+   *
+   * @param payloadRDD A JavaRDD of ArrowPayloads.
+   * @param schemaString JSON Formatted Schema for ArrowPayloads.
+   * @param sqlContext The active [[SQLContext]].
+   * @return The converted [[DataFrame]].
+   */
+  def arrowPayloadToDataFrame(
+      payloadRDD: JavaRDD[Array[Byte]],
+      schemaString: String,
+      sqlContext: SQLContext): DataFrame = {
+    ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
+  }
 }
@@ -29,6 +29,8 @@ import org.apache.arrow.vector.schema.ArrowRecordBatch
 import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 
 import org.apache.spark.TaskContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 import org.apache.spark.sql.types._
@@ -204,4 +206,16 @@ private[sql] object ArrowConverters {
       reader.close()
     }
   }
+
+  private[sql] def toDataFrame(
+      payloadRDD: JavaRDD[Array[Byte]],
+      schemaString: String,
+      sqlContext: SQLContext): DataFrame = {
+    val rdd = payloadRDD.rdd.mapPartitions { iter =>
+      val context = TaskContext.get()
+      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+    }
+    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    sqlContext.internalCreateDataFrame(rdd, schema)
+  }
 }