[SPARK-13534][PYSPARK] Using Apache Arrow to increase performance of DataFrame.toPandas

## What changes were proposed in this pull request?
Integrate Apache Arrow with Spark to increase the performance of `DataFrame.toPandas`.  This is done by using Arrow to convert data partitions on the executor JVM into Arrow payload byte arrays, which are then served to the Python process.  The Python DataFrame collects the Arrow payloads, combines them, and converts them to a Pandas DataFrame.  All data types except complex, date, timestamp, and decimal are currently supported; otherwise an `UnsupportedOperation` exception is thrown.
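
On the Python side, each payload is an Arrow file-format byte array that deserializes into a `pyarrow` table.  A minimal sketch of how such payloads can be combined into a Pandas DataFrame, using the same pyarrow calls as this patch (`payload_bytes_list` is a hypothetical list of collected payload byte arrays):

```python
import pyarrow as pa

def payloads_to_pandas(payload_bytes_list):
    """Combine Arrow file-format payloads (one byte array per batch) into a pandas.DataFrame."""
    tables = []
    for payload in payload_bytes_list:
        # Each payload is a complete Arrow file stream; read all of its record batches as a Table
        reader = pa.RecordBatchFileReader(pa.BufferReader(payload))
        tables.append(reader.read_all())
    # Concatenate the per-partition Tables and convert the result to pandas
    return pa.concat_tables(tables).to_pandas()
```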

Additions to Spark include a Scala package-private method `Dataset.toArrowPayload` that converts data partitions in the executor JVM to `ArrowPayload`s as byte arrays so they can be easily served, and a package-private class/object `ArrowConverters` that provides data type mappings and conversion routines.  In Python, a private method `DataFrame._collectAsArrow` is added to collect the Arrow payloads, and the SQLConf "spark.sql.execution.arrow.enable" can be set so that `toPandas()` uses Arrow (the old conversion is used by default).
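
For example, the new path can be exercised from PySpark roughly as follows (a usage sketch; assumes an existing `SparkSession` named `spark` and pyarrow installed on the driver):

```python
# Opt in to the Arrow-based conversion; the old collect-based path remains the default
spark.conf.set("spark.sql.execution.arrow.enable", "true")

# Optional: cap how many records go into a single ArrowRecordBatch
# (the patch defaults this to 10000; zero or a negative value means no limit)
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

df = spark.range(1000).toDF("id")
pdf = df.toPandas()  # now served as Arrow payloads instead of the row-at-a-time collect path
```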

## How was this patch tested?
Added a new test suite `ArrowConvertersSuite` that tests conversion of Datasets to Arrow payloads for the supported types.  The suite generates a Dataset and matching Arrow JSON data, converts the Dataset to an Arrow payload, and validates it against the JSON data.  This ensures that both the schema and the data have been converted correctly.

Added PySpark tests to verify that `toPandas` produces equal DataFrames with and without pyarrow, plus a round-trip test to ensure the pandas DataFrame produced by PySpark is equal to one made directly with pandas.
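
For reference, the toggle comparison looks roughly like this (a simplified sketch of what the added `ArrowTests` do, not a verbatim excerpt):

```python
def check_arrow_toggle(spark, df):
    """Compare toPandas() output with Arrow disabled vs. enabled."""
    spark.conf.set("spark.sql.execution.arrow.enable", "false")
    pdf_plain = df.toPandas()
    spark.conf.set("spark.sql.execution.arrow.enable", "true")
    pdf_arrow = df.toPandas()
    # pandas' equals() requires matching shape, values, and element dtypes
    assert pdf_plain.equals(pdf_arrow), "Arrow and non-Arrow results differ"
```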

Author: Bryan Cutler <cutlerb@gmail.com>
Author: Li Jin <ice.xelloss@gmail.com>
Author: Li Jin <li.jin@twosigma.com>
Author: Wes McKinney <wes.mckinney@twosigma.com>

Closes #18459 from BryanCutler/toPandas_with_arrow-SPARK-13534.
Bryan Cutler 2017-07-10 15:21:03 -07:00 committed by Holden Karau
parent 2bfd5accdc
commit d03aebbe65
12 changed files with 1859 additions and 13 deletions

bin/pyspark

@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
exec "$PYSPARK_DRIVER_PYTHON" -m "$1"
exec "$PYSPARK_DRIVER_PYTHON" -m "$@"
exit
fi

dev/deps/spark-deps-hadoop-2.6

@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar
api-asn1-api-1.0.0-M20.jar
api-util-1.0.0-M20.jar
arpack_combined_all-0.1.jar
arrow-format-0.4.0.jar
arrow-memory-0.4.0.jar
arrow-vector-0.4.0.jar
avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar
datanucleus-rdbms-3.2.9.jar
derby-10.12.1.1.jar
eigenbase-properties-1.1.5.jar
flatbuffers-1.2.0-3f79e055.jar
gson-2.2.4.jar
guava-14.0.1.jar
guice-3.0.jar
@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar
hk2-api-2.4.0-b34.jar
hk2-locator-2.4.0-b34.jar
hk2-utils-2.4.0-b34.jar
hppc-0.7.1.jar
htrace-core-3.0.4.jar
httpclient-4.5.2.jar
httpcore-4.4.4.jar

dev/deps/spark-deps-hadoop-2.7

@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar
api-asn1-api-1.0.0-M20.jar
api-util-1.0.0-M20.jar
arpack_combined_all-0.1.jar
arrow-format-0.4.0.jar
arrow-memory-0.4.0.jar
arrow-vector-0.4.0.jar
avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar
datanucleus-rdbms-3.2.9.jar
derby-10.12.1.1.jar
eigenbase-properties-1.1.5.jar
flatbuffers-1.2.0-3f79e055.jar
gson-2.2.4.jar
guava-14.0.1.jar
guice-3.0.jar
@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar
hk2-api-2.4.0-b34.jar
hk2-locator-2.4.0-b34.jar
hk2-utils-2.4.0-b34.jar
hppc-0.7.1.jar
htrace-core-3.1.0-incubating.jar
httpclient-4.5.2.jar
httpcore-4.4.4.jar

pom.xml

@ -181,6 +181,7 @@
<paranamer.version>2.6</paranamer.version>
<maven-antrun.version>1.8</maven-antrun.version>
<commons-crypto.version>1.0.0</commons-crypto.version>
<arrow.version>0.4.0</arrow.version>
<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
@ -1878,6 +1879,25 @@
<artifactId>paranamer</artifactId>
<version>${paranamer.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
<version>${arrow.version}</version>
<exclusions>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</exclusion>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
</dependencyManagement>

python/pyspark/serializers.py

@ -182,6 +182,23 @@ class FramedSerializer(Serializer):
        raise NotImplementedError


class ArrowSerializer(FramedSerializer):
    """
    Serializes an Arrow stream.
    """

    def dumps(self, obj):
        raise NotImplementedError

    def loads(self, obj):
        import pyarrow as pa
        reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
        return reader.read_all()

    def __repr__(self):
        return "ArrowSerializer"


class BatchedSerializer(Serializer):
    """

python/pyspark/sql/dataframe.py

@ -29,7 +29,8 @@ import warnings
from pyspark import copy_func, since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
    UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
@ -1710,7 +1711,8 @@ class DataFrame(object):
    @since(1.3)
    def toPandas(self):
        """
        Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

        This is only available if Pandas is installed and available.
@ -1723,18 +1725,42 @@ class DataFrame(object):
        1    5    Bob
        """
        import pandas as pd
        if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true":
            try:
                import pyarrow
                tables = self._collectAsArrow()
                if tables:
                    table = pyarrow.concat_tables(tables)
                    return table.to_pandas()
                else:
                    return pd.DataFrame.from_records([], columns=self.columns)
            except ImportError as e:
                msg = "note: pyarrow must be installed and available on calling Python process " \
                      "if using spark.sql.execution.arrow.enable=true"
                raise ImportError("%s\n%s" % (e.message, msg))
        else:
            dtype = {}
            for field in self.schema:
                pandas_type = _to_corrected_pandas_type(field.dataType)
                if pandas_type is not None:
                    dtype[field.name] = pandas_type

            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)

            for f, t in dtype.items():
                pdf[f] = pdf[f].astype(t, copy=False)
            return pdf

    def _collectAsArrow(self):
        """
        Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
        and available.

        .. note:: Experimental.
        """
        with SCCallSiteSync(self._sc) as css:
            port = self._jdf.collectAsArrowToPython()
        return list(_load_from_socket(port, ArrowSerializer()))
##########################################################################################
# Pandas compatibility

python/pyspark/sql/tests.py

@ -58,12 +58,21 @@ from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException

_have_arrow = False
try:
    import pyarrow
    _have_arrow = True
except:
    # No Arrow, but that's okay, we'll skip those tests
    pass


class UTCOffsetTimezone(datetime.tzinfo):
    """
    Specifies timezone in UTC offset
@ -2843,6 +2852,73 @@ class DataTypeVerificationTests(unittest.TestCase):
                _make_type_verifier(data_type, nullable=False)(obj)


@unittest.skipIf(not _have_arrow, "Arrow not installed")
class ArrowTests(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)
        cls.spark.conf.set("spark.sql.execution.arrow.enable", "true")
        cls.schema = StructType([
            StructField("1_str_t", StringType(), True),
            StructField("2_int_t", IntegerType(), True),
            StructField("3_long_t", LongType(), True),
            StructField("4_float_t", FloatType(), True),
            StructField("5_double_t", DoubleType(), True)])
        cls.data = [("a", 1, 10, 0.2, 2.0),
                    ("b", 2, 20, 0.4, 4.0),
                    ("c", 3, 30, 0.8, 6.0)]

    def assertFramesEqual(self, df_with_arrow, df_without):
        msg = ("DataFrame from Arrow is not equal" +
               ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
               ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
        self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

    def test_unsupported_datatype(self):
        schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
        df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
        with QuietTest(self.sc):
            self.assertRaises(Exception, lambda: df.toPandas())

    def test_null_conversion(self):
        df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
                                             self.data)
        pdf = df_null.toPandas()
        null_counts = pdf.isnull().sum().tolist()
        self.assertTrue(all([c == 1 for c in null_counts]))

    def test_toPandas_arrow_toggle(self):
        df = self.spark.createDataFrame(self.data, schema=self.schema)
        self.spark.conf.set("spark.sql.execution.arrow.enable", "false")
        pdf = df.toPandas()
        self.spark.conf.set("spark.sql.execution.arrow.enable", "true")
        pdf_arrow = df.toPandas()
        self.assertFramesEqual(pdf_arrow, pdf)

    def test_pandas_round_trip(self):
        import pandas as pd
        import numpy as np
        data_dict = {}
        for j, name in enumerate(self.schema.names):
            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
        # need to convert these to numpy types first
        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
        pdf = pd.DataFrame(data=data_dict)
        df = self.spark.createDataFrame(self.data, schema=self.schema)
        pdf_arrow = df.toPandas()
        self.assertFramesEqual(pdf_arrow, pdf)

    def test_filtered_frame(self):
        df = self.spark.range(3).toDF("i")
        pdf = df.filter("i < 0").toPandas()
        self.assertEqual(len(pdf.columns), 1)
        self.assertEqual(pdf.columns[0], "i")
        self.assertTrue(pdf.empty)


if __name__ == "__main__":
    from pyspark.sql.tests import *
    if xmlrunner:

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

@ -855,6 +855,24 @@ object SQLConf {
.intConf
.createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
val ARROW_EXECUTION_ENABLE =
buildConf("spark.sql.execution.arrow.enable")
.internal()
.doc("Make use of Apache Arrow for columnar data transfers. Currently available " +
"for use with pyspark.sql.DataFrame.toPandas with the following data types: " +
"StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " +
"LongType, ShortType")
.booleanConf
.createWithDefault(false)
val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
.internal()
.doc("When using Apache Arrow, limit the maximum number of records that can be written " +
"to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.")
.intConf
.createWithDefault(10000)
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@ -1115,6 +1133,10 @@ class SQLConf extends Serializable with Logging {
def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO)
def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */

sql/core/pom.xml

@ -103,6 +103,10 @@
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.jackson.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
</dependency>
<dependency>
<groupId>org.apache.xbean</groupId>
<artifactId>xbean-asm5-shaded</artifactId>

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
@ -2907,6 +2908,16 @@ class Dataset[T] private[sql](
}
}
/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
*/
private[sql] def collectAsArrowToPython(): Int = {
withNewExecutionId {
val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
PythonRDD.serveIterator(iter, "serve-Arrow")
}
}
private[sql] def toPythonIterator(): Int = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
@ -2988,4 +2999,13 @@ class Dataset[T] private[sql](
Dataset(sparkSession, logicalPlan)
}
}
/** Convert to an RDD of ArrowPayload byte arrays */
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
queryExecution.toRdd.mapPartitionsInternal { iter =>
ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch)
}
}
}

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

@ -0,0 +1,429 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels
import scala.collection.JavaConverters._
import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector._
import org.apache.arrow.vector.BaseValueVector.BaseMutator
import org.apache.arrow.vector.file._
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
* Store Arrow data in a form that can be serialized by Spark and served to a Python process.
*/
private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
/**
* Convert the ArrowPayload to an ArrowRecordBatch.
*/
def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = {
ArrowConverters.byteArrayToBatch(payload, allocator)
}
/**
* Get the ArrowPayload as a type that can be served to Python.
*/
def asPythonSerializable: Array[Byte] = payload
}
private[sql] object ArrowPayload {
/**
* Create an ArrowPayload from an ArrowRecordBatch and Spark schema.
*/
def apply(
batch: ArrowRecordBatch,
schema: StructType,
allocator: BufferAllocator): ArrowPayload = {
new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator))
}
}
private[sql] object ArrowConverters {
/**
* Map a Spark DataType to ArrowType.
*/
private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = {
dataType match {
case BooleanType => ArrowType.Bool.INSTANCE
case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true)
case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true)
case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case ByteType => new ArrowType.Int(8, true)
case StringType => ArrowType.Utf8.INSTANCE
case BinaryType => ArrowType.Binary.INSTANCE
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
}
}
/**
* Convert a Spark Dataset schema to Arrow schema.
*/
private[arrow] def schemaToArrowSchema(schema: StructType): Schema = {
val arrowFields = schema.fields.map { f =>
new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava)
}
new Schema(arrowFields.toList.asJava)
}
/**
* Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload
* by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
*/
private[sql] def toPayloadIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Int): Iterator[ArrowPayload] = {
new Iterator[ArrowPayload] {
private val _allocator = new RootAllocator(Long.MaxValue)
private var _nextPayload = if (rowIter.nonEmpty) convert() else null
override def hasNext: Boolean = _nextPayload != null
override def next(): ArrowPayload = {
val obj = _nextPayload
if (hasNext) {
if (rowIter.hasNext) {
_nextPayload = convert()
} else {
_allocator.close()
_nextPayload = null
}
}
obj
}
private def convert(): ArrowPayload = {
val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch)
ArrowPayload(batch, schema, _allocator)
}
}
}
/**
* Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed
* or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0,
* then rowIter will be fully consumed.
*/
private def internalRowIterToArrowBatch(
rowIter: Iterator[InternalRow],
schema: StructType,
allocator: BufferAllocator,
maxRecordsPerBatch: Int = 0): ArrowRecordBatch = {
val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) =>
ColumnWriter(field.dataType, ordinal, allocator).init()
}
val writerLength = columnWriters.length
var recordsInBatch = 0
while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) {
val row = rowIter.next()
var i = 0
while (i < writerLength) {
columnWriters(i).write(row)
i += 1
}
recordsInBatch += 1
}
val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip
val buffers = bufferArrays.flatten
val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0
val recordBatch = new ArrowRecordBatch(rowLength,
fieldNodes.toList.asJava, buffers.toList.asJava)
buffers.foreach(_.release())
recordBatch
}
/**
* Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed,
* the batch can no longer be used.
*/
private[arrow] def batchToByteArray(
batch: ArrowRecordBatch,
schema: StructType,
allocator: BufferAllocator): Array[Byte] = {
val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val out = new ByteArrayOutputStream()
val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
// Write a batch to byte stream, ensure the batch, allocator and writer are closed
Utils.tryWithSafeFinally {
val loader = new VectorLoader(root)
loader.load(batch)
writer.writeBatch() // writeBatch can throw IOException
} {
batch.close()
root.close()
writer.close()
}
out.toByteArray
}
/**
* Convert a byte array to an ArrowRecordBatch.
*/
private[arrow] def byteArrayToBatch(
batchBytes: Array[Byte],
allocator: BufferAllocator): ArrowRecordBatch = {
val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
val reader = new ArrowFileReader(in, allocator)
// Read a batch from a byte stream, ensure the reader is closed
Utils.tryWithSafeFinally {
val root = reader.getVectorSchemaRoot // throws IOException
val unloader = new VectorUnloader(root)
reader.loadNextBatch() // throws IOException
unloader.getRecordBatch
} {
reader.close()
}
}
}
/**
* Interface for writing InternalRows to Arrow Buffers.
*/
private[arrow] trait ColumnWriter {
def init(): this.type
def write(row: InternalRow): Unit
/**
* Clear the column writer and return the ArrowFieldNode and ArrowBuf.
* This should be called only once after all the data is written.
*/
def finish(): (ArrowFieldNode, Array[ArrowBuf])
}
/**
* Base class for flat arrow column writer, i.e., column without children.
*/
private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int)
extends ColumnWriter {
def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype)
def valueVector: BaseDataValueVector
def valueMutator: BaseMutator
def setNull(): Unit
def setValue(row: InternalRow): Unit
protected var count = 0
protected var nullCount = 0
override def init(): this.type = {
valueVector.allocateNew()
this
}
override def write(row: InternalRow): Unit = {
if (row.isNullAt(ordinal)) {
setNull()
nullCount += 1
} else {
setValue(row)
}
count += 1
}
override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = {
valueMutator.setValueCount(count)
val fieldNode = new ArrowFieldNode(count, nullCount)
val valueBuffers = valueVector.getBuffers(true)
(fieldNode, valueBuffers)
}
}
private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableBitVector
= new NullableBitVector("BooleanValue", getFieldType(dtype), allocator)
override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 )
}
private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableSmallIntVector
= new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator)
override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, row.getShort(ordinal))
}
private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableIntVector
= new NullableIntVector("IntValue", getFieldType(dtype), allocator)
override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, row.getInt(ordinal))
}
private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableBigIntVector
= new NullableBigIntVector("LongValue", getFieldType(dtype), allocator)
override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, row.getLong(ordinal))
}
private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableFloat4Vector
= new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator)
override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, row.getFloat(ordinal))
}
private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableFloat8Vector
= new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator)
override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, row.getDouble(ordinal))
}
private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableUInt1Vector
= new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator)
override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, row.getByte(ordinal))
}
private[arrow] class UTF8StringColumnWriter(
dtype: ArrowType,
ordinal: Int,
allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableVarCharVector
= new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator)
override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit = {
val str = row.getUTF8String(ordinal)
valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes)
}
}
private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableVarBinaryVector
= new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator)
override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit = {
val bytes = row.getBinary(ordinal)
valueMutator.setSafe(count, bytes, 0, bytes.length)
}
}
private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableDateDayVector
= new NullableDateDayVector("DateValue", getFieldType(dtype), allocator)
override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit = {
valueMutator.setSafe(count, row.getInt(ordinal))
}
}
private[arrow] class TimeStampColumnWriter(
dtype: ArrowType,
ordinal: Int,
allocator: BufferAllocator)
extends PrimitiveColumnWriter(ordinal) {
override val valueVector: NullableTimeStampMicroVector
= new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator)
override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator
override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow): Unit = {
valueMutator.setSafe(count, row.getLong(ordinal))
}
}
private[arrow] object ColumnWriter {
/**
* Create an Arrow ColumnWriter given the type and ordinal of row.
*/
def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = {
val dtype = ArrowConverters.sparkTypeToArrowType(dataType)
dataType match {
case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator)
case ShortType => new ShortColumnWriter(dtype, ordinal, allocator)
case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator)
case LongType => new LongColumnWriter(dtype, ordinal, allocator)
case FloatType => new FloatColumnWriter(dtype, ordinal, allocator)
case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator)
case ByteType => new ByteColumnWriter(dtype, ordinal, allocator)
case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator)
case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator)
case DateType => new DateColumnWriter(dtype, ordinal, allocator)
case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator)
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
}
}
}