[SPARK-26830][SQL][R] Vectorized R dapply() implementation
## What changes were proposed in this pull request?

This PR targets to add a vectorized `dapply()` in R via Arrow optimization. It can be tested as below:

```bash
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
df <- createDataFrame(mtcars)
collect(dapply(df, function(rdf) { data.frame(rdf$gear + 1) }, structType("gear double")))
```

### Requirements

- R 3.5.x
- Arrow package 0.12+

```bash
Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.0", subdir = "r")'
```

**Note:** currently, the Arrow R package is not in CRAN. Please take a look at ARROW-3204.
**Note:** currently, the Arrow R package does not seem to support Windows. Please take a look at ARROW-3204.

### Benchmarks

**Shell**

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=false --driver-memory 4g
```

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=true --driver-memory 4g
```

**R code**

```r
rdf <- read.csv("500000.csv")
df <- cache(createDataFrame(rdf))
count(df)

test <- function() {
  options(digits.secs = 6) # milliseconds
  start.time <- Sys.time()
  count(cache(dapply(df, function(rdf) { rdf }, schema(df))))
  end.time <- Sys.time()
  time.taken <- end.time - start.time
  print(time.taken)
}

test()
```

**Data (350 MB):**

```r
object.size(read.csv("500000.csv"))
350379504 bytes
```

"500000 Records" from http://eforexcel.com/wp/downloads-16-sample-csv-files-data-sets-for-testing/

**Results**

With Arrow optimization disabled:

```
Time difference of 13.42037 mins
```

With Arrow optimization enabled:

```
Time difference of 30.64156 secs
```

The performance improvement was around **2627%** (roughly 26x faster).

### Limitations

- For now, Arrow optimization with R does not support the case when the data is `raw`, or when the user explicitly gives a float type in the schema. These produce corrupt values, so they are rejected up front.
- Due to ARROW-4512, it cannot send and receive batch by batch. It has to send all batches in Arrow stream format at once. This needs improvement later.

## How was this patch tested?

Unit tests were added, and manually tested.

Closes #23787 from HyukjinKwon/SPARK-26830-1.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
parent: 0f2c0b53e8
commit: 88bc481b9e
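For illustration, here is a sketch of the FloatType guard described under Limitations (a hypothetical session; it assumes `spark.sql.execution.arrow.enabled=true` and the `arrow` package installed):

```r
# Hypothetical session illustrating the limitation above: an explicit float
# in the output schema is rejected up front instead of producing corrupt values.
df <- createDataFrame(mtcars)
tryCatch(
  collect(dapply(df, function(rdf) { data.frame(rdf$gear + 1) }, structType("gear float"))),
  error = function(e) message(conditionMessage(e))
)
# Expected to print the guard's message from dapplyInternal:
# "Arrow optimization with dapply do not support FloatType yet."
```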
```diff
@@ -1493,6 +1493,29 @@ dapplyInternal <- function(x, func, schema) {
     schema <- structType(schema)
   }
 
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
+  if (arrowEnabled) {
+    requireNamespace1 <- requireNamespace
+    if (!requireNamespace1("arrow", quietly = TRUE)) {
+      stop("'arrow' package should be installed.")
+    }
+    # Currenty Arrow optimization does not support raw for now.
+    # Also, it does not support explicit float type set by users.
+    if (inherits(schema, "structType")) {
+      if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "FloatType"))) {
+        stop("Arrow optimization with dapply do not support FloatType yet.")
+      }
+      if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "BinaryType"))) {
+        stop("Arrow optimization with dapply do not support BinaryType yet.")
+      }
+    } else if (is.null(schema)) {
+      stop(paste0("Arrow optimization does not support 'dapplyCollect' yet. Please disable ",
+                  "Arrow optimization or use 'collect' and 'dapply' APIs instead."))
+    } else {
+      stop("'schema' should be DDL-formatted string or structType.")
+    }
+  }
+
   packageNamesArr <- serialize(.sparkREnv[[".packages"]],
                                connection = NULL)
 
```
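The `requireNamespace1 <- requireNamespace` indirection above works around CRAN's static dependency check, since `arrow` is not on CRAN yet (ARROW-3204). A minimal standalone sketch of the pattern:

```r
# Binding requireNamespace to another name keeps CRAN's static check from
# flagging the undeclared, not-yet-on-CRAN 'arrow' dependency (ARROW-3204).
requireNamespace1 <- requireNamespace
if (requireNamespace1("arrow", quietly = TRUE)) {
  message("'arrow' is available; the Arrow code path can be taken")
} else {
  stop("'arrow' package should be installed.")
}
```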
```diff
@@ -247,17 +247,21 @@ readDeserializeInArrow <- function(inputCon) {
     batches <- RecordBatchStreamReader(arrowData)$batches()
 
     # Read all groupped batches. Tibble -> data.frame is cheap.
-    data <- lapply(batches, function(batch) as.data.frame(as_tibble(batch)))
-
-    # Read keys to map with each groupped batch.
-    keys <- readMultipleObjects(inputCon)
-
-    list(keys = keys, data = data)
+    lapply(batches, function(batch) as.data.frame(as_tibble(batch)))
   } else {
     stop("'arrow' package should be installed.")
   }
 }
 
+readDeserializeWithKeysInArrow <- function(inputCon) {
+  data <- readDeserializeInArrow(inputCon)
+
+  keys <- readMultipleObjects(inputCon)
+
+  # Read keys to map with each groupped batch later.
+  list(keys = keys, data = data)
+}
+
 readRowList <- function(obj) {
   # readRowList is meant for use inside an lapply. As a result, it is
   # necessary to open a standalone connection for the row and consume
```
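A standalone sketch of the read path above, assuming the arrow 0.12 R API this PR targets (`RecordBatchStreamReader` and `as_tibble` as used in the diff; `arrowData` stands in for a raw vector read from the socket):

```r
library(arrow)

# 'arrowData' is assumed to be a raw vector holding one whole Arrow stream;
# per ARROW-4512 it arrives in a single chunk rather than batch by batch.
batches <- RecordBatchStreamReader(arrowData)$batches()

# Tibble -> data.frame conversion is cheap, so each record batch goes
# through as_tibble() on its way to a plain data.frame.
dfs <- lapply(batches, function(batch) as.data.frame(as_tibble(batch)))
```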
```diff
@@ -220,3 +220,18 @@ writeArgs <- function(con, args) {
     }
   }
 }
+
+writeSerializeInArrow <- function(conn, df) {
+  # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204.
+  requireNamespace1 <- requireNamespace
+  if (requireNamespace1("arrow", quietly = TRUE)) {
+    write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)
+
+    # There looks no way to send each batch in streaming format via socket
+    # connection. See ARROW-4512.
+    # So, it writes the whole Arrow streaming-formatted binary at once for now.
+    writeRaw(conn, write_arrow(df, raw()))
+  } else {
+    stop("'arrow' package should be installed.")
+  }
+}
```
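And the corresponding write path, again against the arrow 0.12 API: `write_arrow(df, raw())` returns the whole data frame as one Arrow streaming-formatted raw vector (a sketch, not SparkR's exact call site):

```r
# Look the function up through the namespace, mirroring the CRAN-check hack above.
write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)

bin <- write_arrow(mtcars, raw())  # raw vector containing the Arrow stream
length(bin)                        # size in bytes of the serialized frame
```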
```diff
@@ -76,6 +76,8 @@ outputResult <- function(serializer, output, outputCon) {
     SparkR:::writeRawSerialize(outputCon, output)
   } else if (serializer == "row") {
     SparkR:::writeRowSerialize(outputCon, output)
+  } else if (serializer == "arrow") {
+    SparkR:::writeSerializeInArrow(outputCon, output)
   } else {
     # write lines one-by-one with flag
     lapply(output, function(line) SparkR:::writeString(outputCon, line))
@@ -172,9 +174,15 @@ if (isEmpty != 0) {
   } else if (deserializer == "row") {
     data <- SparkR:::readMultipleObjects(inputCon)
   } else if (deserializer == "arrow" && mode == 2) {
-    dataWithKeys <- SparkR:::readDeserializeInArrow(inputCon)
+    dataWithKeys <- SparkR:::readDeserializeWithKeysInArrow(inputCon)
     keys <- dataWithKeys$keys
     data <- dataWithKeys$data
+  } else if (deserializer == "arrow" && mode == 1) {
+    data <- SparkR:::readDeserializeInArrow(inputCon)
+    # See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
+    # rbind.fill might be an anternative to make it faster if plyr is installed.
+    # Also, note that, 'dapply' applies a function to each partition.
+    data <- do.call("rbind", data)
   }
 
   # Timing reading input data for execution
@@ -192,7 +200,7 @@ if (isEmpty != 0) {
       output <- compute(mode, partition, serializer, deserializer, keys[[i]],
                         colNames, computeFunc, data[[i]])
       computeElap <- elapsedSecs()
-      if (deserializer == "arrow") {
+      if (serializer == "arrow") {
         outputs[[length(outputs) + 1L]] <- output
       } else {
         outputResult(serializer, output, outputCon)
@@ -202,22 +210,11 @@ if (isEmpty != 0) {
       outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap)
     }
 
-    if (deserializer == "arrow") {
-      # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204.
-      requireNamespace1 <- requireNamespace
-      if (requireNamespace1("arrow", quietly = TRUE)) {
-        write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)
+    if (serializer == "arrow") {
       # See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
      # rbind.fill might be an anternative to make it faster if plyr is installed.
       combined <- do.call("rbind", outputs)
 
-        # Likewise, there looks no way to send each batch in streaming format via socket
-        # connection. See ARROW-4512.
-        # So, it writes the whole Arrow streaming-formatted binary at once for now.
-        SparkR:::writeRaw(outputCon, write_arrow(combined, raw()))
-      } else {
-        stop("'arrow' package should be installed.")
-      }
+      SparkR:::writeSerializeInArrow(outputCon, combined)
     }
-    }
   } else {
```
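Since `dapply()` hands the user function one data frame per partition, the worker concatenates the per-batch data frames first; a standalone sketch of that combine step:

```r
# Stand-ins for the data frames decoded from individual Arrow record batches.
data <- list(head(mtcars, 2), tail(mtcars, 2))

# do.call("rbind", ...) concatenates them into the single data frame the user
# function sees; rbind.fill from plyr is noted in the diff as a faster alternative.
combined <- do.call("rbind", data)
nrow(combined)  # 4
```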
```diff
@@ -3317,6 +3317,105 @@ test_that("dapplyCollect() on DataFrame with a binary column", {
 
 })
 
+test_that("dapply() Arrow optimization", {
+  skip_if_not_installed("arrow")
+  df <- createDataFrame(mtcars)
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
+  tryCatch({
+    ret <- dapply(df,
+                  function(rdf) {
+                    stopifnot(class(rdf) == "data.frame")
+                    rdf
+                  },
+                  schema(df))
+    expected <- collect(ret)
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    ret <- dapply(df,
+                  function(rdf) {
+                    stopifnot(class(rdf) == "data.frame")
+                    # mtcars' hp is more then 50.
+                    stopifnot(all(rdf$hp > 50))
+                    rdf
+                  },
+                  schema(df))
+    actual <- collect(ret)
+    expect_equal(actual, expected)
+    expect_equal(count(ret), nrow(mtcars))
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
+test_that("dapply() Arrow optimization - type specification", {
+  skip_if_not_installed("arrow")
+  # Note that regular dapply() seems not supporting date and timestamps
+  # whereas Arrow-optimized dapply() does.
+  rdf <- data.frame(list(list(a = 1,
+                              b = "a",
+                              c = TRUE,
+                              d = 1.1,
+                              e = 1L)))
+  # numPartitions are set to 8 intentionally to test empty partitions as well.
+  df <- createDataFrame(rdf, numPartitions = 8)
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
+  tryCatch({
+    ret <- dapply(df, function(rdf) { rdf }, schema(df))
+    expected <- collect(ret)
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    ret <- dapply(df, function(rdf) { rdf }, schema(df))
+    actual <- collect(ret)
+    expect_equal(actual, expected)
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
+test_that("dapply() Arrow optimization - type specification (date and timestamp)", {
+  skip_if_not_installed("arrow")
+  rdf <- data.frame(list(list(a = as.Date("1990-02-24"),
+                              b = as.POSIXct("1990-02-24 12:34:56"))))
+  df <- createDataFrame(rdf)
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    ret <- dapply(df, function(rdf) { rdf }, schema(df))
+    expect_equal(collect(ret), rdf)
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
 test_that("repartition by columns on DataFrame", {
   # The tasks here launch R workers with shuffles. So, we decrease the number of shuffle
   # partitions to reduce the number of the tasks to speed up the test. This is particularly
```
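The tests above flip the conf through the JVM handle; from user code the same toggle can be set when starting the session, e.g. (a sketch):

```r
# Enable Arrow optimization for the whole SparkR session.
sparkR.session(sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))
```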
```diff
@@ -123,16 +123,25 @@ object MapPartitionsInR {
       schema: StructType,
       encoder: ExpressionEncoder[Row],
       child: LogicalPlan): LogicalPlan = {
+    if (SQLConf.get.arrowEnabled) {
+      MapPartitionsInRWithArrow(
+        func,
+        packageNames,
+        broadcastVars,
+        encoder.schema,
+        schema.toAttributes,
+        child)
+    } else {
       val deserialized = CatalystSerde.deserialize(child)(encoder)
-    val mapped = MapPartitionsInR(
+      CatalystSerde.serialize(MapPartitionsInR(
         func,
         packageNames,
         broadcastVars,
         encoder.schema,
         schema,
         CatalystSerde.generateObjAttr(RowEncoder(schema)),
-      deserialized)
-    CatalystSerde.serialize(mapped)(RowEncoder(schema))
+        deserialized))(RowEncoder(schema))
+    }
   }
 }
 
@@ -154,6 +163,28 @@ case class MapPartitionsInR(
     outputObjAttr, child)
 }
 
+/**
+ * Similar with `MapPartitionsInR` but serializes and deserializes input/output in
+ * Arrow format.
+ *
+ * This is somewhat similar with `org.apache.spark.sql.execution.python.ArrowEvalPython`
+ */
+case class MapPartitionsInRWithArrow(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  // This operator always need all columns of its child, even it doesn't reference to.
+  override def references: AttributeSet = child.outputSet
+
+  override protected def stringArgs: Iterator[Any] = Iterator(
+    inputSchema, StructType.fromAttributes(output), child)
+
+  override val producedAttributes = AttributeSet(output)
+}
+
 object MapElements {
   def apply[T : Encoder, U : Encoder](
       func: AnyRef,
```
```diff
@@ -599,6 +599,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.FlatMapGroupsInRWithArrow(f, p, b, is, ot, key, grouping, child) =>
         execution.FlatMapGroupsInRWithArrowExec(
           f, p, b, is, ot, key, grouping, planLater(child)) :: Nil
+      case logical.MapPartitionsInRWithArrow(f, p, b, is, ot, child) =>
+        execution.MapPartitionsInRWithArrowExec(
+          f, p, b, is, ot, planLater(child)) :: Nil
       case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
         execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
```
```diff
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.python.BatchIterator
 import org.apache.spark.sql.execution.r.ArrowRRunner
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.internal.SQLConf
```
```diff
@@ -193,6 +194,80 @@ case class MapPartitionsExec(
   }
 }
 
+/**
+ * Similar with [[MapPartitionsExec]] and
+ * [[org.apache.spark.sql.execution.r.MapPartitionsRWrapper]] but serializes and deserializes
+ * input/output in Arrow format.
+ *
+ * This is somewhat similar with [[org.apache.spark.sql.execution.python.ArrowEvalPythonExec]]
+ */
+case class MapPartitionsInRWithArrowExec(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    output: Seq[Attribute],
+    child: SparkPlan) extends UnaryExecNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output)
+
+  private val batchSize = conf.arrowMaxRecordsPerBatch
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { inputIter =>
+      val outputTypes = schema.map(_.dataType)
+
+      // DO NOT use iter.grouped(). See BatchIterator.
+      val batchIter =
+        if (batchSize > 0) new BatchIterator(inputIter, batchSize) else Iterator(inputIter)
+
+      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
+        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_DAPPLY)
+
+      // The communication mechanism is as follows:
+      //
+      //    JVM side                         R side
+      //
+      // 1. Internal rows        -------->   Arrow record batches
+      // 2.                                  Converts each Arrow record batch to each R data frame
+      // 3.                                  Combine R data frames into one R data frame
+      // 4.                                  Computes R native function on the data frame
+      // 5.                                  Converts the R data frame to Arrow record batches
+      // 6. Columnar batches     <--------   Arrow record batches
+      // 7. Each row from each batch
+      //
+      // Note that, unlike Python vectorization implementation, R side sends Arrow formatted
+      // binary in a batch due to the limitation of R API. See also ARROW-4512.
+      val columnarBatchIter = runner.compute(batchIter, -1)
+      val outputProject = UnsafeProjection.create(output, output)
+      new Iterator[InternalRow] {
+
+        private var currentIter = if (columnarBatchIter.hasNext) {
+          val batch = columnarBatchIter.next()
+          val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
+          assert(outputTypes == actualDataTypes, "Invalid schema from dapply(): " +
+            s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
+          batch.rowIterator.asScala
+        } else {
+          Iterator.empty
+        }
+
+        override def hasNext: Boolean = currentIter.hasNext || {
+          if (columnarBatchIter.hasNext) {
+            currentIter = columnarBatchIter.next().rowIterator.asScala
+            hasNext
+          } else {
+            false
+          }
+        }
+
+        override def next(): InternalRow = currentIter.next()
+      }.map(outputProject)
+    }
+  }
+}
+
 /**
  * Applies the given function to each input object.
  * The output of its child must be a single-field row containing the input object.
```
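The `batchSize` above comes from `spark.sql.execution.arrow.maxRecordsPerBatch`; from SparkR it can be tuned at session start, e.g. (a sketch):

```r
# Cap each Arrow record batch sent to the R worker at 2048 rows (illustrative value).
sparkR.session(sparkConfig = list(spark.sql.execution.arrow.maxRecordsPerBatch = "2048"))
```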
```diff
@@ -473,8 +548,8 @@ case class FlatMapGroupsInRWithArrowExec(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val runner = new ArrowRRunner(
-        func, packageNames, broadcastVars, inputSchema, SQLConf.get.sessionLocalTimeZone)
+      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
+        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY)
 
       val groupedByRKey = grouped.map { case (key, rowIter) =>
         val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
```
```diff
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType
  * This is necessary because sometimes we cannot hold reference of input rows
  * because the some input rows are mutable and can be reused.
  */
-private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
+private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int)
   extends Iterator[Iterator[T]] {
 
   override def hasNext: Boolean = iter.hasNext
```
```diff
@@ -45,7 +45,8 @@ class ArrowRRunner(
     packageNames: Array[Byte],
     broadcastVars: Array[Broadcast[Object]],
     schema: StructType,
-    timeZoneId: String)
+    timeZoneId: String,
+    mode: Int)
   extends RRunner[ColumnarBatch](
     func,
     "arrow",
@@ -55,13 +56,32 @@ class ArrowRRunner(
     numPartitions = -1,
     isDataFrame = true,
     schema.fieldNames,
-    RRunnerModes.DATAFRAME_GAPPLY) {
+    mode) {
+
+  // TODO: it needs to refactor to share the same code with RRunner, and have separate
+  // ArrowRRunners.
+  private val getNextBatch = {
+    if (mode == RRunnerModes.DATAFRAME_GAPPLY) {
+      // gapply
+      (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => {
+        val (key, nextBatch) = inputIterator
+          .asInstanceOf[Iterator[(Array[Byte], Iterator[InternalRow])]].next()
+        keys.append(key)
+        nextBatch
+      }
+    } else {
+      // dapply
+      (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => {
+        inputIterator
+          .asInstanceOf[Iterator[Iterator[InternalRow]]].next()
+      }
+    }
+  }
 
   protected override def writeData(
       dataOut: DataOutputStream,
       printOut: PrintStream,
-      iter: Iterator[_]): Unit = if (iter.hasNext) {
-    val inputIterator = iter.asInstanceOf[Iterator[(Array[Byte], Iterator[InternalRow])]]
+      inputIterator: Iterator[_]): Unit = if (inputIterator.hasNext) {
     val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
     val allocator = ArrowUtils.rootAllocator.newChildAllocator(
       "stdout writer for R", 0, Long.MaxValue)
@@ -75,8 +95,7 @@ class ArrowRRunner(
     writer.start()
 
     while (inputIterator.hasNext) {
-      val (key, nextBatch) = inputIterator.next()
-      keys.append(key)
+      val nextBatch: Iterator[InternalRow] = getNextBatch(inputIterator, keys)
 
       while (nextBatch.hasNext) {
         arrowWriter.write(nextBatch.next())
```
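With the runner now mode-aware, the same Arrow path serves both APIs; e.g., assuming Arrow is enabled and `arrow` installed:

```r
df <- createDataFrame(mtcars)

# DATAFRAME_DAPPLY: per-partition user function
collect(dapply(df, function(rdf) { rdf }, schema(df)))

# DATAFRAME_GAPPLY: per-group user function, with keys handled by getNextBatch above
collect(gapply(df, "gear", function(key, rdf) { rdf }, schema(df)))
```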