[SPARK-23030][SQL][PYTHON] Use Arrow stream format for creating from and collecting Pandas DataFrames
## What changes were proposed in this pull request? This changes the calls of `toPandas()` and `createDataFrame()` to use the Arrow stream format, when Arrow is enabled. Previously, Arrow data was written to byte arrays where each chunk is an output of the Arrow file format. This was mainly due to constraints at the time, and caused some overhead by writing the schema/footer on each chunk of data and then having to read multiple Arrow file inputs and concat them together. Using the Arrow stream format has improved these by increasing performance, lower memory overhead for the average case, and simplified the code. Here are the details of this change: **toPandas()** _Before:_ Spark internal rows are converted to Arrow file format, each group of records is a complete Arrow file which contains the schema and other metadata. Next a collect is done and an Array of Arrow files is the result. After that each Arrow file is sent to Python driver which then loads each file and concats them to a single Arrow DataFrame. _After:_ Spark internal rows are converted to ArrowRecordBatches directly, which is the simplest Arrow component for IPC data transfers. The driver JVM then immediately starts serving data to Python as an Arrow stream, sending the schema first. It then starts a Spark job with a custom handler that sends Arrow RecordBatches to Python. Partitions arriving in order are sent immediately, and out-of-order partitions are buffered until the ones that precede it come in. This improves performance, simplifies memory usage on executors, and improves the average memory usage on the JVM driver. Since the order of partitions must be preserved, the worst case is that the first partition will be the last to arrive all data must be buffered in memory until then. This case is no worse that before when doing a full collect. **createDataFrame()** _Before:_ A Pandas DataFrame is split into parts and each part is made into an Arrow file. Then each file is prefixed by the buffer size and written to a temp file. The temp file is read and each Arrow file is parallelized as a byte array. _After:_ A Pandas DataFrame is split into parts, then an Arrow stream is written to a temp file where each part is an ArrowRecordBatch. The temp file is read as a stream and the Arrow messages are examined. If the message is an ArrowRecordBatch, the data is saved as a byte array. After reading the file, each ArrowRecordBatch is parallelized as a byte array. This has slightly more processing than before because we must look each Arrow message to extract the record batches, but performance ends up a litle better. It is cleaner in the sense that IPC from Python to JVM is done over a single Arrow stream. ## How was this patch tested? Added new unit tests for the additions to ArrowConverters in Scala, existing tests for Python. ## Performance Tests - toPandas Tests run on a 4 node standalone cluster with 32 cores total, 14.04.1-Ubuntu and OpenJDK 8 measured wall clock time to execute `toPandas()` and took the average best time of 5 runs/5 loops each. Test code ```python df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand()).withColumn("x4", rand()) for i in range(5): start = time.time() _ = df.toPandas() elapsed = time.time() - start ``` Current Master | This PR ---------------------|------------ 5.803557 | 5.16207 5.409119 | 5.133671 5.493509 | 5.147513 5.433107 | 5.105243 5.488757 | 5.018685 Avg Master | Avg This PR ------------------|-------------- 5.5256098 | 5.1134364 Speedup of **1.08060595** ## Performance Tests - createDataFrame Tests run on a 4 node standalone cluster with 32 cores total, 14.04.1-Ubuntu and OpenJDK 8 measured wall clock time to execute `createDataFrame()` and get the first record. Took the average best time of 5 runs/5 loops each. Test code ```python def run(): pdf = pd.DataFrame(np.random.rand(10000000, 10)) spark.createDataFrame(pdf).first() for i in range(6): start = time.time() run() elapsed = time.time() - start gc.collect() print("Run %d: %f" % (i, elapsed)) ``` Current Master | This PR --------------------|---------- 6.234608 | 5.665641 6.32144 | 5.3475 6.527859 | 5.370803 6.95089 | 5.479151 6.235046 | 5.529167 Avg Master | Avg This PR ---------------|---------------- 6.4539686 | 5.4784524 Speedup of **1.178064192** ## Memory Improvements **toPandas()** The most significant improvement is reduction of the upper bound space complexity in the JVM driver. Before, the entire dataset was collected in the JVM first before sending it to Python. With this change, as soon as a partition is collected, the result handler immediately sends it to Python, so the upper bound is the size of the largest partition. Also, using the Arrow stream format is more efficient because the schema is written once per stream, followed by record batches. The schema is now only send from driver JVM to Python. Before, multiple Arrow file formats were used that each contained the schema. This duplicated schema was created in the executors, sent to the driver JVM, and then Python where all but the first one received are discarded. I verified the upper bound limit by running a test that would collect data that would exceed the amount of driver JVM memory available. Using these settings on a standalone cluster: ``` spark.driver.memory 1g spark.executor.memory 5g spark.sql.execution.arrow.enabled true spark.sql.execution.arrow.fallback.enabled false spark.sql.execution.arrow.maxRecordsPerBatch 0 spark.driver.maxResultSize 2g ``` Test code: ```python from pyspark.sql.functions import rand df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand()) df.toPandas() ``` This makes total data size of 33554432×8×4 = 1073741824 With the current master, it fails with OOM but passes using this PR. **createDataFrame()** No significant change in memory except that using the stream format instead of separate file formats avoids duplicated the schema, similar to toPandas above. The process of reading the stream and parallelizing the batches does cause the record batch message metadata to be copied, but it's size is insignificant. Closes #21546 from BryanCutler/arrow-toPandas-stream-SPARK-23030. Authored-by: Bryan Cutler <cutlerb@gmail.com> Signed-off-by: hyukjinkwon <gurwls223@apache.org>
This commit is contained in:
parent
ff8dcc1d4c
commit
82c18c240a
|
@ -399,6 +399,26 @@ private[spark] object PythonRDD extends Logging {
|
|||
* data collected from this job, and the secret for authentication.
|
||||
*/
|
||||
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
|
||||
serveToStream(threadName) { out =>
|
||||
writeIteratorToStream(items, new DataOutputStream(out))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a socket server and background thread to execute the writeFunc
|
||||
* with the given OutputStream.
|
||||
*
|
||||
* The socket server can only accept one connection, or close if no connection
|
||||
* in 15 seconds.
|
||||
*
|
||||
* Once a connection comes in, it will execute the block of code and pass in
|
||||
* the socket output stream.
|
||||
*
|
||||
* The thread will terminate after the block of code is executed or any
|
||||
* exceptions happen.
|
||||
*/
|
||||
private[spark] def serveToStream(
|
||||
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
|
||||
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
|
||||
// Close the socket if no connection in 15 seconds
|
||||
serverSocket.setSoTimeout(15000)
|
||||
|
@ -410,9 +430,9 @@ private[spark] object PythonRDD extends Logging {
|
|||
val sock = serverSocket.accept()
|
||||
authHelper.authClient(sock)
|
||||
|
||||
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
|
||||
val out = new BufferedOutputStream(sock.getOutputStream)
|
||||
Utils.tryWithSafeFinally {
|
||||
writeIteratorToStream(items, out)
|
||||
writeFunc(out)
|
||||
} {
|
||||
out.close()
|
||||
sock.close()
|
||||
|
|
|
@ -494,10 +494,14 @@ class SparkContext(object):
|
|||
c = list(c) # Make it a list so we can compute its length
|
||||
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
|
||||
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
|
||||
jrdd = self._serialize_to_jvm(c, numSlices, serializer)
|
||||
|
||||
def reader_func(temp_filename):
|
||||
return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
|
||||
|
||||
jrdd = self._serialize_to_jvm(c, serializer, reader_func)
|
||||
return RDD(jrdd, self, serializer)
|
||||
|
||||
def _serialize_to_jvm(self, data, parallelism, serializer):
|
||||
def _serialize_to_jvm(self, data, serializer, reader_func):
|
||||
"""
|
||||
Calling the Java parallelize() method with an ArrayList is too slow,
|
||||
because it sends O(n) Py4J commands. As an alternative, serialized
|
||||
|
@ -507,8 +511,7 @@ class SparkContext(object):
|
|||
try:
|
||||
serializer.dump_stream(data, tempFile)
|
||||
tempFile.close()
|
||||
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
|
||||
return readRDDFromFile(self._jsc, tempFile.name, parallelism)
|
||||
return reader_func(tempFile.name)
|
||||
finally:
|
||||
# readRDDFromFile eagerily reads the file so we can delete right after.
|
||||
os.unlink(tempFile.name)
|
||||
|
|
|
@ -185,27 +185,31 @@ class FramedSerializer(Serializer):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class ArrowSerializer(FramedSerializer):
|
||||
class ArrowStreamSerializer(Serializer):
|
||||
"""
|
||||
Serializes bytes as Arrow data with the Arrow file format.
|
||||
Serializes Arrow record batches as a stream.
|
||||
"""
|
||||
|
||||
def dumps(self, batch):
|
||||
def dump_stream(self, iterator, stream):
|
||||
import pyarrow as pa
|
||||
import io
|
||||
sink = io.BytesIO()
|
||||
writer = pa.RecordBatchFileWriter(sink, batch.schema)
|
||||
writer = None
|
||||
try:
|
||||
for batch in iterator:
|
||||
if writer is None:
|
||||
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
|
||||
writer.write_batch(batch)
|
||||
finally:
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
return sink.getvalue()
|
||||
|
||||
def loads(self, obj):
|
||||
def load_stream(self, stream):
|
||||
import pyarrow as pa
|
||||
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
|
||||
return reader.read_all()
|
||||
reader = pa.open_stream(stream)
|
||||
for batch in reader:
|
||||
yield batch
|
||||
|
||||
def __repr__(self):
|
||||
return "ArrowSerializer"
|
||||
return "ArrowStreamSerializer"
|
||||
|
||||
|
||||
def _create_batch(series, timezone):
|
||||
|
|
|
@ -29,7 +29,7 @@ import warnings
|
|||
|
||||
from pyspark import copy_func, since, _NoValue
|
||||
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
|
||||
from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
|
||||
from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
|
||||
UTF8Deserializer
|
||||
from pyspark.storagelevel import StorageLevel
|
||||
from pyspark.traceback_utils import SCCallSiteSync
|
||||
|
@ -2118,10 +2118,9 @@ class DataFrame(object):
|
|||
from pyspark.sql.types import _check_dataframe_convert_date, \
|
||||
_check_dataframe_localize_timestamps
|
||||
import pyarrow
|
||||
|
||||
tables = self._collectAsArrow()
|
||||
if tables:
|
||||
table = pyarrow.concat_tables(tables)
|
||||
batches = self._collectAsArrow()
|
||||
if len(batches) > 0:
|
||||
table = pyarrow.Table.from_batches(batches)
|
||||
pdf = table.to_pandas()
|
||||
pdf = _check_dataframe_convert_date(pdf, self.schema)
|
||||
return _check_dataframe_localize_timestamps(pdf, timezone)
|
||||
|
@ -2170,14 +2169,14 @@ class DataFrame(object):
|
|||
|
||||
def _collectAsArrow(self):
|
||||
"""
|
||||
Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
|
||||
and available.
|
||||
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
|
||||
and available on driver and worker Python environments.
|
||||
|
||||
.. note:: Experimental.
|
||||
"""
|
||||
with SCCallSiteSync(self._sc) as css:
|
||||
sock_info = self._jdf.collectAsArrowToPython()
|
||||
return list(_load_from_socket(sock_info, ArrowSerializer()))
|
||||
return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
|
||||
|
||||
##########################################################################################
|
||||
# Pandas compatibility
|
||||
|
|
|
@ -501,7 +501,7 @@ class SparkSession(object):
|
|||
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
|
||||
data types will be used to coerce the data in Pandas to Arrow conversion.
|
||||
"""
|
||||
from pyspark.serializers import ArrowSerializer, _create_batch
|
||||
from pyspark.serializers import ArrowStreamSerializer, _create_batch
|
||||
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
|
||||
from pyspark.sql.utils import require_minimum_pandas_version, \
|
||||
require_minimum_pyarrow_version
|
||||
|
@ -539,10 +539,12 @@ class SparkSession(object):
|
|||
struct.names[i] = name
|
||||
schema = struct
|
||||
|
||||
# Create the Spark DataFrame directly from the Arrow data and schema
|
||||
jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
|
||||
jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
|
||||
jrdd, schema.json(), self._wrapped._jsqlContext)
|
||||
def reader_func(temp_filename):
|
||||
return self._jvm.PythonSQLUtils.arrowReadStreamFromFile(
|
||||
self._wrapped._jsqlContext, temp_filename, schema.json())
|
||||
|
||||
# Create Spark DataFrame from Arrow stream file, using one batch per partition
|
||||
jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func)
|
||||
df = DataFrame(jdf, self._wrapped)
|
||||
df._schema = schema
|
||||
return df
|
||||
|
|
|
@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
|
|||
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
|
||||
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
|
||||
import org.apache.spark.sql.execution.command._
|
||||
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||
import org.apache.spark.sql.execution.python.EvaluatePython
|
||||
|
@ -3273,13 +3273,49 @@ class Dataset[T] private[sql](
|
|||
}
|
||||
|
||||
/**
|
||||
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
|
||||
* Collect a Dataset as Arrow batches and serve stream to PySpark.
|
||||
*/
|
||||
private[sql] def collectAsArrowToPython(): Array[Any] = {
|
||||
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
|
||||
|
||||
withAction("collectAsArrowToPython", queryExecution) { plan =>
|
||||
val iter: Iterator[Array[Byte]] =
|
||||
toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
|
||||
PythonRDD.serveIterator(iter, "serve-Arrow")
|
||||
PythonRDD.serveToStream("serve-Arrow") { out =>
|
||||
val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
|
||||
val arrowBatchRdd = toArrowBatchRdd(plan)
|
||||
val numPartitions = arrowBatchRdd.partitions.length
|
||||
|
||||
// Store collection results for worst case of 1 to N-1 partitions
|
||||
val results = new Array[Array[Array[Byte]]](numPartitions - 1)
|
||||
var lastIndex = -1 // index of last partition written
|
||||
|
||||
// Handler to eagerly write partitions to Python in order
|
||||
def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
|
||||
// If result is from next partition in order
|
||||
if (index - 1 == lastIndex) {
|
||||
batchWriter.writeBatches(arrowBatches.iterator)
|
||||
lastIndex += 1
|
||||
// Write stored partitions that come next in order
|
||||
while (lastIndex < results.length && results(lastIndex) != null) {
|
||||
batchWriter.writeBatches(results(lastIndex).iterator)
|
||||
results(lastIndex) = null
|
||||
lastIndex += 1
|
||||
}
|
||||
// After last batch, end the stream
|
||||
if (lastIndex == results.length) {
|
||||
batchWriter.end()
|
||||
}
|
||||
} else {
|
||||
// Store partitions received out of order
|
||||
results(index - 1) = arrowBatches
|
||||
}
|
||||
}
|
||||
|
||||
sparkSession.sparkContext.runJob(
|
||||
arrowBatchRdd,
|
||||
(ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
|
||||
0 until numPartitions,
|
||||
handlePartitionBatches)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3386,20 +3422,20 @@ class Dataset[T] private[sql](
|
|||
}
|
||||
}
|
||||
|
||||
/** Convert to an RDD of ArrowPayload byte arrays */
|
||||
private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
|
||||
/** Convert to an RDD of serialized ArrowRecordBatches. */
|
||||
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
|
||||
val schemaCaptured = this.schema
|
||||
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
|
||||
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
|
||||
plan.execute().mapPartitionsInternal { iter =>
|
||||
val context = TaskContext.get()
|
||||
ArrowConverters.toPayloadIterator(
|
||||
ArrowConverters.toBatchIterator(
|
||||
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
|
||||
}
|
||||
}
|
||||
|
||||
// This is only used in tests, for now.
|
||||
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
|
||||
toArrowPayload(queryExecution.executedPlan)
|
||||
private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
|
||||
toArrowBatchRdd(queryExecution.executedPlan)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
package org.apache.spark.sql.api.python
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
|
||||
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
|
||||
|
@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils {
|
|||
}
|
||||
|
||||
/**
|
||||
* Python Callable function to convert ArrowPayloads into a [[DataFrame]].
|
||||
* Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
|
||||
* using each serialized ArrowRecordBatch as a partition.
|
||||
*
|
||||
* @param payloadRDD A JavaRDD of ArrowPayloads.
|
||||
* @param schemaString JSON Formatted Schema for ArrowPayloads.
|
||||
* @param sqlContext The active [[SQLContext]].
|
||||
* @return The converted [[DataFrame]].
|
||||
* @param filename File to read the Arrow stream from.
|
||||
* @param schemaString JSON Formatted Spark schema for Arrow batches.
|
||||
* @return A new [[DataFrame]].
|
||||
*/
|
||||
def arrowPayloadToDataFrame(
|
||||
payloadRDD: JavaRDD[Array[Byte]],
|
||||
schemaString: String,
|
||||
sqlContext: SQLContext): DataFrame = {
|
||||
ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
|
||||
def arrowReadStreamFromFile(
|
||||
sqlContext: SQLContext,
|
||||
filename: String,
|
||||
schemaString: String): DataFrame = {
|
||||
val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
|
||||
ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,73 +17,75 @@
|
|||
|
||||
package org.apache.spark.sql.execution.arrow
|
||||
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.nio.channels.Channels
|
||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream}
|
||||
import java.nio.channels.{Channels, SeekableByteChannel}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.arrow.flatbuf.MessageHeader
|
||||
import org.apache.arrow.memory.BufferAllocator
|
||||
import org.apache.arrow.vector._
|
||||
import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter}
|
||||
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
|
||||
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
|
||||
import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
|
||||
import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.{ByteBufferOutputStream, Utils}
|
||||
|
||||
|
||||
/**
|
||||
* Store Arrow data in a form that can be serialized by Spark and served to a Python process.
|
||||
* Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format.
|
||||
*/
|
||||
private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {
|
||||
private[sql] class ArrowBatchStreamWriter(
|
||||
schema: StructType,
|
||||
out: OutputStream,
|
||||
timeZoneId: String) {
|
||||
|
||||
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
|
||||
val writeChannel = new WriteChannel(Channels.newChannel(out))
|
||||
|
||||
// Write the Arrow schema first, before batches
|
||||
MessageSerializer.serialize(writeChannel, arrowSchema)
|
||||
|
||||
/**
|
||||
* Convert the ArrowPayload to an ArrowRecordBatch.
|
||||
* Consume iterator to write each serialized ArrowRecordBatch to the stream.
|
||||
*/
|
||||
def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = {
|
||||
ArrowConverters.byteArrayToBatch(payload, allocator)
|
||||
def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = {
|
||||
arrowBatchIter.foreach(writeChannel.write)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the ArrowPayload as a type that can be served to Python.
|
||||
* End the Arrow stream, does not close output stream.
|
||||
*/
|
||||
def asPythonSerializable: Array[Byte] = payload
|
||||
}
|
||||
|
||||
/**
|
||||
* Iterator interface to iterate over Arrow record batches and return rows
|
||||
*/
|
||||
private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {
|
||||
|
||||
/**
|
||||
* Return the schema loaded from the Arrow record batch being iterated over
|
||||
*/
|
||||
def schema: StructType
|
||||
def end(): Unit = {
|
||||
ArrowStreamWriter.writeEndOfStream(writeChannel)
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] object ArrowConverters {
|
||||
|
||||
/**
|
||||
* Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload
|
||||
* by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
|
||||
* Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
|
||||
* in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
|
||||
*/
|
||||
private[sql] def toPayloadIterator(
|
||||
private[sql] def toBatchIterator(
|
||||
rowIter: Iterator[InternalRow],
|
||||
schema: StructType,
|
||||
maxRecordsPerBatch: Int,
|
||||
timeZoneId: String,
|
||||
context: TaskContext): Iterator[ArrowPayload] = {
|
||||
context: TaskContext): Iterator[Array[Byte]] = {
|
||||
|
||||
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
|
||||
val allocator =
|
||||
ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)
|
||||
ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue)
|
||||
|
||||
val root = VectorSchemaRoot.create(arrowSchema, allocator)
|
||||
val unloader = new VectorUnloader(root)
|
||||
val arrowWriter = ArrowWriter.create(root)
|
||||
|
||||
context.addTaskCompletionListener[Unit] { _ =>
|
||||
|
@ -91,7 +93,7 @@ private[sql] object ArrowConverters {
|
|||
allocator.close()
|
||||
}
|
||||
|
||||
new Iterator[ArrowPayload] {
|
||||
new Iterator[Array[Byte]] {
|
||||
|
||||
override def hasNext: Boolean = rowIter.hasNext || {
|
||||
root.close()
|
||||
|
@ -99,9 +101,9 @@ private[sql] object ArrowConverters {
|
|||
false
|
||||
}
|
||||
|
||||
override def next(): ArrowPayload = {
|
||||
override def next(): Array[Byte] = {
|
||||
val out = new ByteArrayOutputStream()
|
||||
val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
|
||||
val writeChannel = new WriteChannel(Channels.newChannel(out))
|
||||
|
||||
Utils.tryWithSafeFinally {
|
||||
var rowCount = 0
|
||||
|
@ -111,45 +113,46 @@ private[sql] object ArrowConverters {
|
|||
rowCount += 1
|
||||
}
|
||||
arrowWriter.finish()
|
||||
writer.writeBatch()
|
||||
val batch = unloader.getRecordBatch()
|
||||
MessageSerializer.serialize(writeChannel, batch)
|
||||
batch.close()
|
||||
} {
|
||||
arrowWriter.reset()
|
||||
writer.close()
|
||||
}
|
||||
|
||||
new ArrowPayload(out.toByteArray)
|
||||
out.toByteArray
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps Iterator from ArrowPayload to InternalRow. Returns a pair containing the row iterator
|
||||
* and the schema from the first batch of Arrow data read.
|
||||
* Maps iterator from serialized ArrowRecordBatches to InternalRows.
|
||||
*/
|
||||
private[sql] def fromPayloadIterator(
|
||||
payloadIter: Iterator[ArrowPayload],
|
||||
context: TaskContext): ArrowRowIterator = {
|
||||
private[sql] def fromBatchIterator(
|
||||
arrowBatchIter: Iterator[Array[Byte]],
|
||||
schema: StructType,
|
||||
timeZoneId: String,
|
||||
context: TaskContext): Iterator[InternalRow] = {
|
||||
val allocator =
|
||||
ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)
|
||||
ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue)
|
||||
|
||||
new ArrowRowIterator {
|
||||
private var reader: ArrowFileReader = null
|
||||
private var schemaRead = StructType(Seq.empty)
|
||||
private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty
|
||||
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
|
||||
val root = VectorSchemaRoot.create(arrowSchema, allocator)
|
||||
|
||||
new Iterator[InternalRow] {
|
||||
private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty
|
||||
|
||||
context.addTaskCompletionListener[Unit] { _ =>
|
||||
closeReader()
|
||||
root.close()
|
||||
allocator.close()
|
||||
}
|
||||
|
||||
override def schema: StructType = schemaRead
|
||||
|
||||
override def hasNext: Boolean = rowIter.hasNext || {
|
||||
closeReader()
|
||||
if (payloadIter.hasNext) {
|
||||
if (arrowBatchIter.hasNext) {
|
||||
rowIter = nextBatch()
|
||||
true
|
||||
} else {
|
||||
root.close()
|
||||
allocator.close()
|
||||
false
|
||||
}
|
||||
|
@ -157,19 +160,11 @@ private[sql] object ArrowConverters {
|
|||
|
||||
override def next(): InternalRow = rowIter.next()
|
||||
|
||||
private def closeReader(): Unit = {
|
||||
if (reader != null) {
|
||||
reader.close()
|
||||
reader = null
|
||||
}
|
||||
}
|
||||
|
||||
private def nextBatch(): Iterator[InternalRow] = {
|
||||
val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
|
||||
reader = new ArrowFileReader(in, allocator)
|
||||
reader.loadNextBatch() // throws IOException
|
||||
val root = reader.getVectorSchemaRoot // throws IOException
|
||||
schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)
|
||||
val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator)
|
||||
val vectorLoader = new VectorLoader(root)
|
||||
vectorLoader.load(arrowRecordBatch)
|
||||
arrowRecordBatch.close()
|
||||
|
||||
val columns = root.getFieldVectors.asScala.map { vector =>
|
||||
new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
|
||||
|
@ -183,34 +178,106 @@ private[sql] object ArrowConverters {
|
|||
}
|
||||
|
||||
/**
|
||||
* Convert a byte array to an ArrowRecordBatch.
|
||||
* Load a serialized ArrowRecordBatch.
|
||||
*/
|
||||
private[arrow] def byteArrayToBatch(
|
||||
private[arrow] def loadBatch(
|
||||
batchBytes: Array[Byte],
|
||||
allocator: BufferAllocator): ArrowRecordBatch = {
|
||||
val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
|
||||
val reader = new ArrowFileReader(in, allocator)
|
||||
|
||||
// Read a batch from a byte stream, ensure the reader is closed
|
||||
Utils.tryWithSafeFinally {
|
||||
val root = reader.getVectorSchemaRoot // throws IOException
|
||||
val unloader = new VectorUnloader(root)
|
||||
reader.loadNextBatch() // throws IOException
|
||||
unloader.getRecordBatch
|
||||
} {
|
||||
reader.close()
|
||||
}
|
||||
val in = new ByteArrayInputStream(batchBytes)
|
||||
MessageSerializer.deserializeRecordBatch(
|
||||
new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
|
||||
*/
|
||||
private[sql] def toDataFrame(
|
||||
payloadRDD: JavaRDD[Array[Byte]],
|
||||
arrowBatchRDD: JavaRDD[Array[Byte]],
|
||||
schemaString: String,
|
||||
sqlContext: SQLContext): DataFrame = {
|
||||
val rdd = payloadRDD.rdd.mapPartitions { iter =>
|
||||
val context = TaskContext.get()
|
||||
ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
|
||||
}
|
||||
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
|
||||
val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
|
||||
val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
|
||||
val context = TaskContext.get()
|
||||
ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
|
||||
}
|
||||
sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema)
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
|
||||
*/
|
||||
private[sql] def readArrowStreamFromFile(
|
||||
sqlContext: SQLContext,
|
||||
filename: String): JavaRDD[Array[Byte]] = {
|
||||
Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
|
||||
// Create array to consume iterator so that we can safely close the file
|
||||
val batches = getBatchesFromStream(fileStream.getChannel).toArray
|
||||
// Parallelize the record batches to create an RDD
|
||||
JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
|
||||
*/
|
||||
private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
|
||||
|
||||
// Iterate over the serialized Arrow RecordBatch messages from a stream
|
||||
new Iterator[Array[Byte]] {
|
||||
var batch: Array[Byte] = readNextBatch()
|
||||
|
||||
override def hasNext: Boolean = batch != null
|
||||
|
||||
override def next(): Array[Byte] = {
|
||||
val prevBatch = batch
|
||||
batch = readNextBatch()
|
||||
prevBatch
|
||||
}
|
||||
|
||||
// This gets the next serialized ArrowRecordBatch by reading message metadata to check if it
|
||||
// is a RecordBatch message and then returning the complete serialized message which consists
|
||||
// of a int32 length, serialized message metadata and a serialized RecordBatch message body
|
||||
def readNextBatch(): Array[Byte] = {
|
||||
val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in))
|
||||
if (msgMetadata == null) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Get the length of the body, which has not been read at this point
|
||||
val bodyLength = msgMetadata.getMessageBodyLength.toInt
|
||||
|
||||
// Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages
|
||||
if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) {
|
||||
|
||||
// Buffer backed output large enough to hold the complete serialized message
|
||||
val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
|
||||
|
||||
// Write message metadata to ByteBuffer output stream
|
||||
MessageSerializer.writeMessageBuffer(
|
||||
new WriteChannel(Channels.newChannel(bbout)),
|
||||
msgMetadata.getMessageLength,
|
||||
msgMetadata.getMessageBuffer)
|
||||
|
||||
// Get a zero-copy ByteBuffer with already contains message metadata, must close first
|
||||
bbout.close()
|
||||
val bb = bbout.toByteBuffer
|
||||
bb.position(bbout.getCount())
|
||||
|
||||
// Read message body directly into the ByteBuffer to avoid copy, return backed byte array
|
||||
bb.limit(bb.capacity())
|
||||
JavaUtils.readFully(in, bb)
|
||||
bb.array()
|
||||
} else {
|
||||
if (bodyLength > 0) {
|
||||
// Skip message body if not a RecordBatch
|
||||
in.position(in.position() + bodyLength)
|
||||
}
|
||||
|
||||
// Proceed to next message
|
||||
readNextBatch()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
*/
|
||||
package org.apache.spark.sql.execution.arrow
|
||||
|
||||
import java.io.File
|
||||
import java.io.{ByteArrayOutputStream, DataOutputStream, File}
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.text.SimpleDateFormat
|
||||
|
@ -26,7 +26,7 @@ import com.google.common.io.Files
|
|||
import org.apache.arrow.memory.RootAllocator
|
||||
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
|
||||
import org.apache.arrow.vector.ipc.JsonFileReader
|
||||
import org.apache.arrow.vector.util.Validator
|
||||
import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator}
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.{SparkException, TaskContext}
|
||||
|
@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
|
||||
test("collect to arrow record batch") {
|
||||
val indexData = (1 to 6).toDF("i")
|
||||
val arrowPayloads = indexData.toArrowPayload.collect()
|
||||
assert(arrowPayloads.nonEmpty)
|
||||
assert(arrowPayloads.length == indexData.rdd.getNumPartitions)
|
||||
val arrowBatches = indexData.toArrowBatchRdd.collect()
|
||||
assert(arrowBatches.nonEmpty)
|
||||
assert(arrowBatches.length == indexData.rdd.getNumPartitions)
|
||||
val allocator = new RootAllocator(Long.MaxValue)
|
||||
val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
|
||||
val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator))
|
||||
val rowCount = arrowRecordBatches.map(_.getLength).sum
|
||||
assert(rowCount === indexData.count())
|
||||
arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0))
|
||||
|
@ -1153,9 +1153,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
|}
|
||||
""".stripMargin
|
||||
|
||||
val arrowPayloads = testData2.toArrowPayload.collect()
|
||||
// NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload
|
||||
assert(arrowPayloads.length === 2)
|
||||
val arrowBatches = testData2.toArrowBatchRdd.collect()
|
||||
// NOTE: testData2 should have 2 partitions -> 2 arrow batches
|
||||
assert(arrowBatches.length === 2)
|
||||
val schema = testData2.schema
|
||||
|
||||
val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json")
|
||||
|
@ -1163,25 +1163,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
Files.write(json1, tempFile1, StandardCharsets.UTF_8)
|
||||
Files.write(json2, tempFile2, StandardCharsets.UTF_8)
|
||||
|
||||
validateConversion(schema, arrowPayloads(0), tempFile1)
|
||||
validateConversion(schema, arrowPayloads(1), tempFile2)
|
||||
validateConversion(schema, arrowBatches(0), tempFile1)
|
||||
validateConversion(schema, arrowBatches(1), tempFile2)
|
||||
}
|
||||
|
||||
test("empty frame collect") {
|
||||
val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect()
|
||||
assert(arrowPayload.isEmpty)
|
||||
val arrowBatches = spark.emptyDataFrame.toArrowBatchRdd.collect()
|
||||
assert(arrowBatches.isEmpty)
|
||||
|
||||
val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i")
|
||||
val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect()
|
||||
assert(filteredArrowPayload.isEmpty)
|
||||
val filteredArrowBatches = filteredDF.filter("i < 0").toArrowBatchRdd.collect()
|
||||
assert(filteredArrowBatches.isEmpty)
|
||||
}
|
||||
|
||||
test("empty partition collect") {
|
||||
val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i")
|
||||
val arrowPayloads = emptyPart.toArrowPayload.collect()
|
||||
assert(arrowPayloads.length === 1)
|
||||
val arrowBatches = emptyPart.toArrowBatchRdd.collect()
|
||||
assert(arrowBatches.length === 1)
|
||||
val allocator = new RootAllocator(Long.MaxValue)
|
||||
val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
|
||||
val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator))
|
||||
assert(arrowRecordBatches.head.getLength == 1)
|
||||
arrowRecordBatches.foreach(_.close())
|
||||
allocator.close()
|
||||
|
@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
val maxRecordsPerBatch = 3
|
||||
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch)
|
||||
val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i")
|
||||
val arrowPayloads = df.toArrowPayload.collect()
|
||||
assert(arrowPayloads.length >= 4)
|
||||
val arrowBatches = df.toArrowBatchRdd.collect()
|
||||
assert(arrowBatches.length >= 4)
|
||||
val allocator = new RootAllocator(Long.MaxValue)
|
||||
val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
|
||||
val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator))
|
||||
var recordCount = 0
|
||||
arrowRecordBatches.foreach { batch =>
|
||||
assert(batch.getLength > 0)
|
||||
|
@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
assert(msg.getCause.getClass === classOf[UnsupportedOperationException])
|
||||
}
|
||||
|
||||
runUnsupported { mapData.toDF().toArrowPayload.collect() }
|
||||
runUnsupported { complexData.toArrowPayload.collect() }
|
||||
runUnsupported { mapData.toDF().toArrowBatchRdd.collect() }
|
||||
runUnsupported { complexData.toArrowBatchRdd.collect() }
|
||||
}
|
||||
|
||||
test("test Arrow Validator") {
|
||||
|
@ -1318,7 +1318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
}
|
||||
}
|
||||
|
||||
test("roundtrip payloads") {
|
||||
test("roundtrip arrow batches") {
|
||||
val inputRows = (0 until 9).map { i =>
|
||||
InternalRow(i)
|
||||
} :+ InternalRow(null)
|
||||
|
@ -1326,10 +1326,41 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
|
||||
|
||||
val ctx = TaskContext.empty()
|
||||
val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx)
|
||||
val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
|
||||
val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
|
||||
val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx)
|
||||
|
||||
assert(schema == outputRowIter.schema)
|
||||
var count = 0
|
||||
outputRowIter.zipWithIndex.foreach { case (row, i) =>
|
||||
if (i != 9) {
|
||||
assert(row.getInt(0) == i)
|
||||
} else {
|
||||
assert(row.isNullAt(0))
|
||||
}
|
||||
count += 1
|
||||
}
|
||||
|
||||
assert(count == inputRows.length)
|
||||
}
|
||||
|
||||
test("ArrowBatchStreamWriter roundtrip") {
|
||||
val inputRows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null)
|
||||
|
||||
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
|
||||
val ctx = TaskContext.empty()
|
||||
val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
|
||||
|
||||
// Write batches to Arrow stream format as a byte array
|
||||
val out = new ByteArrayOutputStream()
|
||||
Utils.tryWithResource(new DataOutputStream(out)) { dataOut =>
|
||||
val writer = new ArrowBatchStreamWriter(schema, dataOut, null)
|
||||
writer.writeBatches(batchIter)
|
||||
writer.end()
|
||||
}
|
||||
|
||||
// Read Arrow stream into batches, then convert back to rows
|
||||
val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray)
|
||||
val readBatches = ArrowConverters.getBatchesFromStream(in)
|
||||
val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, ctx)
|
||||
|
||||
var count = 0
|
||||
outputRowIter.zipWithIndex.foreach { case (row, i) =>
|
||||
|
@ -1348,15 +1379,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
private def collectAndValidate(
|
||||
df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = {
|
||||
// NOTE: coalesce to single partition because can only load 1 batch in validator
|
||||
val arrowPayload = df.coalesce(1).toArrowPayload.collect().head
|
||||
val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head
|
||||
val tempFile = new File(tempDataPath, file)
|
||||
Files.write(json, tempFile, StandardCharsets.UTF_8)
|
||||
validateConversion(df.schema, arrowPayload, tempFile, timeZoneId)
|
||||
validateConversion(df.schema, batchBytes, tempFile, timeZoneId)
|
||||
}
|
||||
|
||||
private def validateConversion(
|
||||
sparkSchema: StructType,
|
||||
arrowPayload: ArrowPayload,
|
||||
batchBytes: Array[Byte],
|
||||
jsonFile: File,
|
||||
timeZoneId: String = null): Unit = {
|
||||
val allocator = new RootAllocator(Long.MaxValue)
|
||||
|
@ -1368,7 +1399,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|
|||
|
||||
val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator)
|
||||
val vectorLoader = new VectorLoader(arrowRoot)
|
||||
val arrowRecordBatch = arrowPayload.loadBatch(allocator)
|
||||
val arrowRecordBatch = ArrowConverters.loadBatch(batchBytes, allocator)
|
||||
vectorLoader.load(arrowRecordBatch)
|
||||
val jsonRoot = jsonReader.read()
|
||||
Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
|
||||
|
|
Loading…
Reference in a new issue