[SPARK-23030][SQL][PYTHON] Use Arrow stream format for creating from and collecting Pandas DataFrames

## What changes were proposed in this pull request?

This changes `toPandas()` and `createDataFrame()` to use the Arrow stream format when Arrow is enabled. Previously, Arrow data was written to byte arrays where each chunk was an output of the Arrow file format. This was mainly due to constraints at the time, and it caused some overhead by writing the schema/footer on each chunk of data and then having to read multiple Arrow file inputs and concatenate them together.

Using the Arrow stream format improves performance, lowers memory overhead in the average case, and simplifies the code. Here are the details of this change:

**toPandas()**

_Before:_
Spark internal rows are converted to the Arrow file format; each group of records is a complete Arrow file that contains the schema and other metadata. A collect is then done, and the result is an array of Arrow files. Each Arrow file is sent to the Python driver, which loads each file and concatenates them into a single Arrow table that is converted to a Pandas DataFrame.

_After:_
Spark internal rows are converted to ArrowRecordBatches directly, which is the simplest Arrow component for IPC data transfers. The driver JVM then immediately starts serving data to Python as an Arrow stream, sending the schema first. It then starts a Spark job with a custom handler that sends Arrow RecordBatches to Python. Partitions arriving in order are sent immediately, and out-of-order partitions are buffered until the ones that precede them come in. This improves performance, simplifies memory usage on the executors, and improves the average memory usage on the JVM driver. Since the order of partitions must be preserved, the worst case is that the first partition is the last to arrive, and all data must be buffered in memory until then. This case is no worse than before, when doing a full collect.
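
The ordering logic lives in the JVM driver (see the `collectAsArrowToPython` change further down in this diff). As a rough illustration only, here is a minimal Python sketch of the same buffering scheme, where `writer` is a hypothetical stand-in for the JVM-side Arrow stream writer:

```python
class OrderedPartitionWriter:
    """Sketch of the order-preserving result handler (illustrative only).

    `writer` is a hypothetical object with write_batches(batches) and end()
    methods, standing in for the JVM-side ArrowBatchStreamWriter.
    """

    def __init__(self, num_partitions, writer):
        self.writer = writer
        self.num_partitions = num_partitions
        self.pending = [None] * num_partitions  # out-of-order partitions
        self.last_written = -1                  # index of last partition written

    def handle(self, index, batches):
        if index - 1 == self.last_written:
            # Next partition in order: send it to Python immediately.
            self.writer.write_batches(batches)
            self.last_written = index
            # Flush any buffered partitions that now come next in order.
            while (self.last_written + 1 < self.num_partitions
                   and self.pending[self.last_written + 1] is not None):
                self.last_written += 1
                self.writer.write_batches(self.pending[self.last_written])
                self.pending[self.last_written] = None
            # After the last partition, end the stream.
            if self.last_written == self.num_partitions - 1:
                self.writer.end()
        else:
            # Arrived before its predecessors: buffer until they come in.
            self.pending[index] = batches
```

At most the out-of-order partitions are held in memory, and the stream is ended as soon as the final partition has been written.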

**createDataFrame()**

_Before:_
A Pandas DataFrame is split into parts and each part is made into an Arrow file. Each file is then prefixed with its buffer size and written to a temp file. The temp file is read back and each Arrow file is parallelized as a byte array.

_After:_
A Pandas DataFrame is split into parts, then an Arrow stream is written to a temp file where each part is an ArrowRecordBatch. The temp file is read as a stream and the Arrow messages are examined; if a message is an ArrowRecordBatch, its data is saved as a byte array. After reading the file, each ArrowRecordBatch is parallelized as a byte array. This has slightly more processing than before because we must look at each Arrow message to extract the record batches, but performance ends up a little better. It is cleaner in the sense that IPC from Python to the JVM is done over a single Arrow stream.
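
For reference, the Python side now just writes the split DataFrame out as one Arrow stream, roughly what `ArrowStreamSerializer.dump_stream` in this PR does. A simplified standalone sketch (the file path and `num_parts` are illustrative, not what Spark uses):

```python
import numpy as np
import pandas as pd
import pyarrow as pa

# Split a Pandas DataFrame into slices and write them as a single Arrow
# stream: the schema once, then one ArrowRecordBatch per slice.
pdf = pd.DataFrame(np.random.rand(1000, 4), columns=list("abcd"))
num_parts = 8  # illustrative; Spark derives the number of slices itself
slices = np.array_split(pdf, num_parts)
batches = [pa.RecordBatch.from_pandas(s, preserve_index=False) for s in slices]

with open("/tmp/arrow-stream-sketch", "wb") as f:
    writer = pa.RecordBatchStreamWriter(f, batches[0].schema)
    for batch in batches:
        writer.write_batch(batch)
    writer.close()
```

The JVM then scans this stream, keeps each serialized RecordBatch message as a byte array, and parallelizes them with one batch per partition.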

## How was this patch tested?

Added new unit tests for the additions to ArrowConverters in Scala; the existing Python tests cover the rest.

## Performance Tests - toPandas

Tests were run on a 4-node standalone cluster with 32 cores total, Ubuntu 14.04.1, and OpenJDK 8.
Wall clock time to execute `toPandas()` was measured, taking the average best time of 5 runs with 5 loops each.

Test code:
```python
import time
from pyspark.sql.functions import rand

df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand()).withColumn("x4", rand())
for i in range(5):
    start = time.time()
    _ = df.toPandas()
    elapsed = time.time() - start
```

Current Master | This PR
---------------------|------------
5.803557 | 5.16207
5.409119 | 5.133671
5.493509 | 5.147513
5.433107 | 5.105243
5.488757 | 5.018685

Avg Master | Avg This PR
------------------|--------------
5.5256098 | 5.1134364

Speedup of **1.08060595**

## Performance Tests - createDataFrame

Tests were run on a 4-node standalone cluster with 32 cores total, Ubuntu 14.04.1, and OpenJDK 8.
Wall clock time to execute `createDataFrame()` and fetch the first record was measured, taking the average best time of 5 runs with 5 loops each.

Test code:
```python
import gc
import time

import numpy as np
import pandas as pd

def run():
    pdf = pd.DataFrame(np.random.rand(10000000, 10))
    spark.createDataFrame(pdf).first()

for i in range(6):
    start = time.time()
    run()
    elapsed = time.time() - start
    gc.collect()
    print("Run %d: %f" % (i, elapsed))
```

Current Master | This PR
--------------------|----------
6.234608 | 5.665641
6.32144 | 5.3475
6.527859 | 5.370803
6.95089 | 5.479151
6.235046 | 5.529167

Avg Master | Avg This PR
---------------|----------------
6.4539686 | 5.4784524

Speedup of **1.178064192**

## Memory Improvements

**toPandas()**

The most significant improvement is the reduction of the upper-bound space complexity in the JVM driver. Before, the entire dataset was collected in the JVM before being sent to Python. With this change, as soon as a partition is collected, the result handler immediately sends it to Python, so the upper bound is the size of the largest partition. Also, using the Arrow stream format is more efficient because the schema is written once per stream, followed by the record batches. The schema is now sent only once, from the driver JVM to Python. Before, multiple Arrow files were used, each containing the schema; this duplicated schema was created in the executors, sent to the driver JVM, and then to Python, where all but the first one received were discarded.
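
To make the schema overhead concrete, the sketch below (not part of this PR) serializes the same record batch several times with the per-chunk Arrow file format and then once as a single Arrow stream; the stream carries the schema only once:

```python
import io

import numpy as np
import pandas as pd
import pyarrow as pa

batch = pa.RecordBatch.from_pandas(
    pd.DataFrame(np.random.rand(1000, 4), columns=list("abcd")),
    preserve_index=False)

# Old approach: every chunk is a complete Arrow file, so the schema (and the
# file footer) is repeated per chunk.
file_bytes = 0
for _ in range(8):
    sink = io.BytesIO()
    writer = pa.RecordBatchFileWriter(sink, batch.schema)
    writer.write_batch(batch)
    writer.close()
    file_bytes += len(sink.getvalue())

# New approach: one stream, schema written once, then only record batches.
sink = io.BytesIO()
writer = pa.RecordBatchStreamWriter(sink, batch.schema)
for _ in range(8):
    writer.write_batch(batch)
writer.close()
stream_bytes = len(sink.getvalue())

print(file_bytes, stream_bytes)  # the stream total is smaller
```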

I verified the upper bound by running a test that collects more data than the driver JVM has memory available for, using these settings on a standalone cluster:
```
spark.driver.memory 1g
spark.executor.memory 5g
spark.sql.execution.arrow.enabled true
spark.sql.execution.arrow.fallback.enabled false
spark.sql.execution.arrow.maxRecordsPerBatch 0
spark.driver.maxResultSize 2g
```

Test code:
```python
from pyspark.sql.functions import rand
df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand())
df.toPandas()
```

This makes the total data size 33554432 rows × 4 columns × 8 bytes = 1073741824 bytes (1 GiB), more than the configured 1g of driver memory.

With the current master this fails with an OOM error, but it passes using this PR.

**createDataFrame()**

There is no significant change in memory, except that using the stream format instead of separate Arrow files avoids duplicating the schema, similar to `toPandas()` above. The process of reading the stream and parallelizing the batches does cause the record batch message metadata to be copied, but its size is insignificant.

Closes #21546 from BryanCutler/arrow-toPandas-stream-SPARK-23030.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>

@ -399,6 +399,26 @@ private[spark] object PythonRDD extends Logging {
* data collected from this job, and the secret for authentication.
*/
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
serveToStream(threadName) { out =>
writeIteratorToStream(items, new DataOutputStream(out))
}
}
/**
* Create a socket server and background thread to execute the writeFunc
* with the given OutputStream.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
*
* Once a connection comes in, it will execute the block of code and pass in
* the socket output stream.
*
* The thread will terminate after the block of code is executed or any
* exceptions happen.
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
@ -410,9 +430,9 @@ private[spark] object PythonRDD extends Logging {
val sock = serverSocket.accept()
authHelper.authClient(sock)
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
val out = new BufferedOutputStream(sock.getOutputStream)
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
writeFunc(out)
} {
out.close()
sock.close()


@ -494,10 +494,14 @@ class SparkContext(object):
c = list(c) # Make it a list so we can compute its length
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
jrdd = self._serialize_to_jvm(c, numSlices, serializer)
def reader_func(temp_filename):
return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
jrdd = self._serialize_to_jvm(c, serializer, reader_func)
return RDD(jrdd, self, serializer)
def _serialize_to_jvm(self, data, parallelism, serializer):
def _serialize_to_jvm(self, data, serializer, reader_func):
"""
Calling the Java parallelize() method with an ArrayList is too slow,
because it sends O(n) Py4J commands. As an alternative, serialized
@ -507,8 +511,7 @@ class SparkContext(object):
try:
serializer.dump_stream(data, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
return readRDDFromFile(self._jsc, tempFile.name, parallelism)
return reader_func(tempFile.name)
finally:
# readRDDFromFile eagerily reads the file so we can delete right after.
os.unlink(tempFile.name)


@ -185,27 +185,31 @@ class FramedSerializer(Serializer):
raise NotImplementedError
class ArrowSerializer(FramedSerializer):
class ArrowStreamSerializer(Serializer):
"""
Serializes bytes as Arrow data with the Arrow file format.
Serializes Arrow record batches as a stream.
"""
def dumps(self, batch):
def dump_stream(self, iterator, stream):
import pyarrow as pa
import io
sink = io.BytesIO()
writer = pa.RecordBatchFileWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
return sink.getvalue()
writer = None
try:
for batch in iterator:
if writer is None:
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
finally:
if writer is not None:
writer.close()
def loads(self, obj):
def load_stream(self, stream):
import pyarrow as pa
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
return reader.read_all()
reader = pa.open_stream(stream)
for batch in reader:
yield batch
def __repr__(self):
return "ArrowSerializer"
return "ArrowStreamSerializer"
def _create_batch(series, timezone):


@ -29,7 +29,7 @@ import warnings
from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@ -2118,10 +2118,9 @@ class DataFrame(object):
from pyspark.sql.types import _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
import pyarrow
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
batches = self._collectAsArrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
pdf = table.to_pandas()
pdf = _check_dataframe_convert_date(pdf, self.schema)
return _check_dataframe_localize_timestamps(pdf, timezone)
@ -2170,14 +2169,14 @@ class DataFrame(object):
def _collectAsArrow(self):
"""
Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
and available.
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
and available on driver and worker Python environments.
.. note:: Experimental.
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.collectAsArrowToPython()
return list(_load_from_socket(sock_info, ArrowSerializer()))
return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
##########################################################################################
# Pandas compatibility


@ -501,7 +501,7 @@ class SparkSession(object):
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowSerializer, _create_batch
from pyspark.serializers import ArrowStreamSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from pyspark.sql.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version
@ -539,10 +539,12 @@ class SparkSession(object):
struct.names[i] = name
schema = struct
# Create the Spark DataFrame directly from the Arrow data and schema
jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
jrdd, schema.json(), self._wrapped._jsqlContext)
def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.arrowReadStreamFromFile(
self._wrapped._jsqlContext, temp_filename, schema.json())
# Create Spark DataFrame from Arrow stream file, using one batch per partition
jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
return df


@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
@ -3273,13 +3273,49 @@ class Dataset[T] private[sql](
}
/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
* Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
private[sql] def collectAsArrowToPython(): Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
withAction("collectAsArrowToPython", queryExecution) { plan =>
val iter: Iterator[Array[Byte]] =
toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
PythonRDD.serveIterator(iter, "serve-Arrow")
PythonRDD.serveToStream("serve-Arrow") { out =>
val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length
// Store collection results for worst case of 1 to N-1 partitions
val results = new Array[Array[Array[Byte]]](numPartitions - 1)
var lastIndex = -1 // index of last partition written
// Handler to eagerly write partitions to Python in order
def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
// If result is from next partition in order
if (index - 1 == lastIndex) {
batchWriter.writeBatches(arrowBatches.iterator)
lastIndex += 1
// Write stored partitions that come next in order
while (lastIndex < results.length && results(lastIndex) != null) {
batchWriter.writeBatches(results(lastIndex).iterator)
results(lastIndex) = null
lastIndex += 1
}
// After last batch, end the stream
if (lastIndex == results.length) {
batchWriter.end()
}
} else {
// Store partitions received out of order
results(index - 1) = arrowBatches
}
}
sparkSession.sparkContext.runJob(
arrowBatchRdd,
(ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
0 until numPartitions,
handlePartitionBatches)
}
}
}
@ -3386,20 +3422,20 @@ class Dataset[T] private[sql](
}
}
/** Convert to an RDD of ArrowPayload byte arrays */
private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
/** Convert to an RDD of serialized ArrowRecordBatches. */
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toPayloadIterator(
ArrowConverters.toBatchIterator(
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
}
}
// This is only used in tests, for now.
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
toArrowPayload(queryExecution.executedPlan)
private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
toArrowBatchRdd(queryExecution.executedPlan)
}
}


@ -17,7 +17,6 @@
package org.apache.spark.sql.api.python
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils {
}
/**
* Python Callable function to convert ArrowPayloads into a [[DataFrame]].
* Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
* using each serialized ArrowRecordBatch as a partition.
*
* @param payloadRDD A JavaRDD of ArrowPayloads.
* @param schemaString JSON Formatted Schema for ArrowPayloads.
* @param sqlContext The active [[SQLContext]].
* @return The converted [[DataFrame]].
* @param filename File to read the Arrow stream from.
* @param schemaString JSON Formatted Spark schema for Arrow batches.
* @return A new [[DataFrame]].
*/
def arrowPayloadToDataFrame(
payloadRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
def arrowReadStreamFromFile(
sqlContext: SQLContext,
filename: String,
schemaString: String): DataFrame = {
val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext)
}
}


@ -17,73 +17,75 @@
package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream}
import java.nio.channels.{Channels, SeekableByteChannel}
import scala.collection.JavaConverters._
import org.apache.arrow.flatbuf.MessageHeader
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter}
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.Utils
import org.apache.spark.util.{ByteBufferOutputStream, Utils}
/**
* Store Arrow data in a form that can be serialized by Spark and served to a Python process.
* Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format.
*/
private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {
private[sql] class ArrowBatchStreamWriter(
schema: StructType,
out: OutputStream,
timeZoneId: String) {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val writeChannel = new WriteChannel(Channels.newChannel(out))
// Write the Arrow schema first, before batches
MessageSerializer.serialize(writeChannel, arrowSchema)
/**
* Convert the ArrowPayload to an ArrowRecordBatch.
* Consume iterator to write each serialized ArrowRecordBatch to the stream.
*/
def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = {
ArrowConverters.byteArrayToBatch(payload, allocator)
def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = {
arrowBatchIter.foreach(writeChannel.write)
}
/**
* Get the ArrowPayload as a type that can be served to Python.
* End the Arrow stream, does not close output stream.
*/
def asPythonSerializable: Array[Byte] = payload
}
/**
* Iterator interface to iterate over Arrow record batches and return rows
*/
private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {
/**
* Return the schema loaded from the Arrow record batch being iterated over
*/
def schema: StructType
def end(): Unit = {
ArrowStreamWriter.writeEndOfStream(writeChannel)
}
}
private[sql] object ArrowConverters {
/**
* Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload
* by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
* Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
* in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
*/
private[sql] def toPayloadIterator(
private[sql] def toBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Int,
timeZoneId: String,
context: TaskContext): Iterator[ArrowPayload] = {
context: TaskContext): Iterator[Array[Byte]] = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator =
ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)
ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val unloader = new VectorUnloader(root)
val arrowWriter = ArrowWriter.create(root)
context.addTaskCompletionListener[Unit] { _ =>
@ -91,7 +93,7 @@ private[sql] object ArrowConverters {
allocator.close()
}
new Iterator[ArrowPayload] {
new Iterator[Array[Byte]] {
override def hasNext: Boolean = rowIter.hasNext || {
root.close()
@ -99,9 +101,9 @@ private[sql] object ArrowConverters {
false
}
override def next(): ArrowPayload = {
override def next(): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
val writeChannel = new WriteChannel(Channels.newChannel(out))
Utils.tryWithSafeFinally {
var rowCount = 0
@ -111,45 +113,46 @@ private[sql] object ArrowConverters {
rowCount += 1
}
arrowWriter.finish()
writer.writeBatch()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)
batch.close()
} {
arrowWriter.reset()
writer.close()
}
new ArrowPayload(out.toByteArray)
out.toByteArray
}
}
}
/**
* Maps Iterator from ArrowPayload to InternalRow. Returns a pair containing the row iterator
* and the schema from the first batch of Arrow data read.
* Maps iterator from serialized ArrowRecordBatches to InternalRows.
*/
private[sql] def fromPayloadIterator(
payloadIter: Iterator[ArrowPayload],
context: TaskContext): ArrowRowIterator = {
private[sql] def fromBatchIterator(
arrowBatchIter: Iterator[Array[Byte]],
schema: StructType,
timeZoneId: String,
context: TaskContext): Iterator[InternalRow] = {
val allocator =
ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)
ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue)
new ArrowRowIterator {
private var reader: ArrowFileReader = null
private var schemaRead = StructType(Seq.empty)
private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
new Iterator[InternalRow] {
private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty
context.addTaskCompletionListener[Unit] { _ =>
closeReader()
root.close()
allocator.close()
}
override def schema: StructType = schemaRead
override def hasNext: Boolean = rowIter.hasNext || {
closeReader()
if (payloadIter.hasNext) {
if (arrowBatchIter.hasNext) {
rowIter = nextBatch()
true
} else {
root.close()
allocator.close()
false
}
@ -157,19 +160,11 @@ private[sql] object ArrowConverters {
override def next(): InternalRow = rowIter.next()
private def closeReader(): Unit = {
if (reader != null) {
reader.close()
reader = null
}
}
private def nextBatch(): Iterator[InternalRow] = {
val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
reader = new ArrowFileReader(in, allocator)
reader.loadNextBatch() // throws IOException
val root = reader.getVectorSchemaRoot // throws IOException
schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)
val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator)
val vectorLoader = new VectorLoader(root)
vectorLoader.load(arrowRecordBatch)
arrowRecordBatch.close()
val columns = root.getFieldVectors.asScala.map { vector =>
new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
@ -183,34 +178,106 @@ private[sql] object ArrowConverters {
}
/**
* Convert a byte array to an ArrowRecordBatch.
* Load a serialized ArrowRecordBatch.
*/
private[arrow] def byteArrayToBatch(
private[arrow] def loadBatch(
batchBytes: Array[Byte],
allocator: BufferAllocator): ArrowRecordBatch = {
val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
val reader = new ArrowFileReader(in, allocator)
// Read a batch from a byte stream, ensure the reader is closed
Utils.tryWithSafeFinally {
val root = reader.getVectorSchemaRoot // throws IOException
val unloader = new VectorUnloader(root)
reader.loadNextBatch() // throws IOException
unloader.getRecordBatch
} {
reader.close()
}
val in = new ByteArrayInputStream(batchBytes)
MessageSerializer.deserializeRecordBatch(
new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
}
/**
* Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
*/
private[sql] def toDataFrame(
payloadRDD: JavaRDD[Array[Byte]],
arrowBatchRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
val rdd = payloadRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
}
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
}
sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
/**
* Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
*/
private[sql] def readArrowStreamFromFile(
sqlContext: SQLContext,
filename: String): JavaRDD[Array[Byte]] = {
Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
// Create array to consume iterator so that we can safely close the file
val batches = getBatchesFromStream(fileStream.getChannel).toArray
// Parallelize the record batches to create an RDD
JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
}
}
/**
* Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
*/
private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
// Iterate over the serialized Arrow RecordBatch messages from a stream
new Iterator[Array[Byte]] {
var batch: Array[Byte] = readNextBatch()
override def hasNext: Boolean = batch != null
override def next(): Array[Byte] = {
val prevBatch = batch
batch = readNextBatch()
prevBatch
}
// This gets the next serialized ArrowRecordBatch by reading message metadata to check if it
// is a RecordBatch message and then returning the complete serialized message which consists
// of a int32 length, serialized message metadata and a serialized RecordBatch message body
def readNextBatch(): Array[Byte] = {
val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in))
if (msgMetadata == null) {
return null
}
// Get the length of the body, which has not been read at this point
val bodyLength = msgMetadata.getMessageBodyLength.toInt
// Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages
if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) {
// Buffer backed output large enough to hold the complete serialized message
val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
// Write message metadata to ByteBuffer output stream
MessageSerializer.writeMessageBuffer(
new WriteChannel(Channels.newChannel(bbout)),
msgMetadata.getMessageLength,
msgMetadata.getMessageBuffer)
// Get a zero-copy ByteBuffer with already contains message metadata, must close first
bbout.close()
val bb = bbout.toByteBuffer
bb.position(bbout.getCount())
// Read message body directly into the ByteBuffer to avoid copy, return backed byte array
bb.limit(bb.capacity())
JavaUtils.readFully(in, bb)
bb.array()
} else {
if (bodyLength > 0) {
// Skip message body if not a RecordBatch
in.position(in.position() + bodyLength)
}
// Proceed to next message
readNextBatch()
}
}
}
}
}


@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.arrow
import java.io.File
import java.io.{ByteArrayOutputStream, DataOutputStream, File}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
@ -26,7 +26,7 @@ import com.google.common.io.Files
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.JsonFileReader
import org.apache.arrow.vector.util.Validator
import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator}
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{SparkException, TaskContext}
@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
test("collect to arrow record batch") {
val indexData = (1 to 6).toDF("i")
val arrowPayloads = indexData.toArrowPayload.collect()
assert(arrowPayloads.nonEmpty)
assert(arrowPayloads.length == indexData.rdd.getNumPartitions)
val arrowBatches = indexData.toArrowBatchRdd.collect()
assert(arrowBatches.nonEmpty)
assert(arrowBatches.length == indexData.rdd.getNumPartitions)
val allocator = new RootAllocator(Long.MaxValue)
val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator))
val rowCount = arrowRecordBatches.map(_.getLength).sum
assert(rowCount === indexData.count())
arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0))
@ -1153,9 +1153,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
|}
""".stripMargin
val arrowPayloads = testData2.toArrowPayload.collect()
// NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload
assert(arrowPayloads.length === 2)
val arrowBatches = testData2.toArrowBatchRdd.collect()
// NOTE: testData2 should have 2 partitions -> 2 arrow batches
assert(arrowBatches.length === 2)
val schema = testData2.schema
val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json")
@ -1163,25 +1163,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
Files.write(json1, tempFile1, StandardCharsets.UTF_8)
Files.write(json2, tempFile2, StandardCharsets.UTF_8)
validateConversion(schema, arrowPayloads(0), tempFile1)
validateConversion(schema, arrowPayloads(1), tempFile2)
validateConversion(schema, arrowBatches(0), tempFile1)
validateConversion(schema, arrowBatches(1), tempFile2)
}
test("empty frame collect") {
val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect()
assert(arrowPayload.isEmpty)
val arrowBatches = spark.emptyDataFrame.toArrowBatchRdd.collect()
assert(arrowBatches.isEmpty)
val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i")
val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect()
assert(filteredArrowPayload.isEmpty)
val filteredArrowBatches = filteredDF.filter("i < 0").toArrowBatchRdd.collect()
assert(filteredArrowBatches.isEmpty)
}
test("empty partition collect") {
val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i")
val arrowPayloads = emptyPart.toArrowPayload.collect()
assert(arrowPayloads.length === 1)
val arrowBatches = emptyPart.toArrowBatchRdd.collect()
assert(arrowBatches.length === 1)
val allocator = new RootAllocator(Long.MaxValue)
val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator))
assert(arrowRecordBatches.head.getLength == 1)
arrowRecordBatches.foreach(_.close())
allocator.close()
@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
val maxRecordsPerBatch = 3
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch)
val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i")
val arrowPayloads = df.toArrowPayload.collect()
assert(arrowPayloads.length >= 4)
val arrowBatches = df.toArrowBatchRdd.collect()
assert(arrowBatches.length >= 4)
val allocator = new RootAllocator(Long.MaxValue)
val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator))
var recordCount = 0
arrowRecordBatches.foreach { batch =>
assert(batch.getLength > 0)
@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
assert(msg.getCause.getClass === classOf[UnsupportedOperationException])
}
runUnsupported { mapData.toDF().toArrowPayload.collect() }
runUnsupported { complexData.toArrowPayload.collect() }
runUnsupported { mapData.toDF().toArrowBatchRdd.collect() }
runUnsupported { complexData.toArrowBatchRdd.collect() }
}
test("test Arrow Validator") {
@ -1318,7 +1318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
}
}
test("roundtrip payloads") {
test("roundtrip arrow batches") {
val inputRows = (0 until 9).map { i =>
InternalRow(i)
} :+ InternalRow(null)
@ -1326,10 +1326,41 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx)
val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx)
assert(schema == outputRowIter.schema)
var count = 0
outputRowIter.zipWithIndex.foreach { case (row, i) =>
if (i != 9) {
assert(row.getInt(0) == i)
} else {
assert(row.isNullAt(0))
}
count += 1
}
assert(count == inputRows.length)
}
test("ArrowBatchStreamWriter roundtrip") {
val inputRows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null)
val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val ctx = TaskContext.empty()
val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
// Write batches to Arrow stream format as a byte array
val out = new ByteArrayOutputStream()
Utils.tryWithResource(new DataOutputStream(out)) { dataOut =>
val writer = new ArrowBatchStreamWriter(schema, dataOut, null)
writer.writeBatches(batchIter)
writer.end()
}
// Read Arrow stream into batches, then convert back to rows
val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray)
val readBatches = ArrowConverters.getBatchesFromStream(in)
val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, ctx)
var count = 0
outputRowIter.zipWithIndex.foreach { case (row, i) =>
@ -1348,15 +1379,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
private def collectAndValidate(
df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = {
// NOTE: coalesce to single partition because can only load 1 batch in validator
val arrowPayload = df.coalesce(1).toArrowPayload.collect().head
val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head
val tempFile = new File(tempDataPath, file)
Files.write(json, tempFile, StandardCharsets.UTF_8)
validateConversion(df.schema, arrowPayload, tempFile, timeZoneId)
validateConversion(df.schema, batchBytes, tempFile, timeZoneId)
}
private def validateConversion(
sparkSchema: StructType,
arrowPayload: ArrowPayload,
batchBytes: Array[Byte],
jsonFile: File,
timeZoneId: String = null): Unit = {
val allocator = new RootAllocator(Long.MaxValue)
@ -1368,7 +1399,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator)
val vectorLoader = new VectorLoader(arrowRoot)
val arrowRecordBatch = arrowPayload.loadBatch(allocator)
val arrowRecordBatch = ArrowConverters.loadBatch(batchBytes, allocator)
vectorLoader.load(arrowRecordBatch)
val jsonRoot = jsonReader.read()
Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)