[SPARK-26762][SQL][R] Arrow optimization for conversion from Spark DataFrame to R DataFrame
## What changes were proposed in this pull request?

This PR adds Arrow optimization support for the conversion from Spark DataFrame to R DataFrame. As on the PySpark side, it falls back to the non-optimized code path when Arrow optimization cannot be used.

This can be tested as below:

```bash
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
collect(createDataFrame(mtcars))
```

### Requirements

- R 3.5.x
- Arrow package 0.12+

```bash
Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.0", subdir = "r")'
```

**Note:** currently, the Arrow R package is not on CRAN. Please take a look at ARROW-3204.
**Note:** currently, the Arrow R package does not seem to support Windows. Please take a look at ARROW-3204.

### Benchmarks

**Shell**

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=false --driver-memory 4g
```

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=true --driver-memory 4g
```

**R code**

```r
df <- cache(createDataFrame(read.csv("500000.csv")))
count(df)

test <- function() {
  options(digits.secs = 6) # milliseconds
  start.time <- Sys.time()
  collect(df)
  end.time <- Sys.time()
  time.taken <- end.time - start.time
  print(time.taken)
}

test()
```

**Data (350 MB):**

```r
object.size(read.csv("500000.csv"))
350379504 bytes
```

"500000 Records" http://eforexcel.com/wp/downloads-16-sample-csv-files-data-sets-for-testing/

**Results**

Without Arrow optimization:

```
Time difference of 221.32014 secs
```

With Arrow optimization:

```
Time difference of 15.51145 secs
```

The performance improvement was around **1426%**.

### Limitations:

- For now, Arrow optimization with R does not support the case where the data is `raw` or where the user explicitly gives a float type in the schema; both produce corrupt values. In these cases we fall back to the non-optimized code path (an illustrative R sketch follows this commit message).
- Due to ARROW-4512, it cannot send and receive batch by batch. It has to send all batches in Arrow stream format at once. This needs improvement later.

## How was this patch tested?

Existing tests related to Arrow optimization cover this change. Also, manually tested.

Closes #23760 from HyukjinKwon/SPARK-26762.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
parent ab850c02f7 · commit 3c15d8b71c
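To illustrate the fallback described under Limitations, here is a minimal, untested R sketch. The `mpg_float` column and the inline session config are made up for this example; the fallback behaviour (a warning, then the non-Arrow path) follows the `collect()` changes in the diff below.

```r
library(SparkR)
sparkR.session(sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))

df <- createDataFrame(mtcars)
# Hypothetical column: casting to "float" gives the schema an explicit FloatType field.
df <- withColumn(df, "mpg_float", cast(df$mpg, "float"))

# collect() first attempts the Arrow path, detects the FloatType column, emits a
# warning, and falls back to the non-optimized path; the result is still a plain
# R data.frame.
head(collect(df))
```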
```diff
@@ -1177,11 +1177,67 @@ setMethod("dim",
 setMethod("collect",
           signature(x = "SparkDataFrame"),
           function(x, stringsAsFactors = FALSE) {
+            connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
+            useArrow <- FALSE
+            arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
+            if (arrowEnabled) {
+              useArrow <- tryCatch({
+                requireNamespace1 <- requireNamespace
+                if (!requireNamespace1("arrow", quietly = TRUE)) {
+                  stop("'arrow' package should be installed.")
+                }
+                # Currently Arrow optimization does not support raw for now.
+                # Also, it does not support explicit float type set by users.
+                if (inherits(schema(x), "structType")) {
+                  if (any(sapply(schema(x)$fields(),
+                                 function(x) x$dataType.toString() == "FloatType"))) {
+                    stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
+                                "DataFrame does not support FloatType yet."))
+                  }
+                  if (any(sapply(schema(x)$fields(),
+                                 function(x) x$dataType.toString() == "BinaryType"))) {
+                    stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
+                                "DataFrame does not support BinaryType yet."))
+                  }
+                }
+                TRUE
+              }, error = function(e) {
+                warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ",
+                               "with Arrow optimization because ",
+                               "'spark.sql.execution.arrow.enabled' is set to true; however, ",
+                               "failed, attempting non-optimization. Reason: ",
+                               e))
+                FALSE
+              })
+            }
+
             dtypes <- dtypes(x)
             ncol <- length(dtypes)
             if (ncol <= 0) {
               # empty data.frame with 0 columns and 0 rows
               data.frame()
+            } else if (useArrow) {
+              requireNamespace1 <- requireNamespace
+              if (requireNamespace1("arrow", quietly = TRUE)) {
+                read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE)
+                as_tibble <- get("as_tibble", envir = asNamespace("arrow"))
+
+                portAuth <- callJMethod(x@sdf, "collectAsArrowToR")
+                port <- portAuth[[1]]
+                authSecret <- portAuth[[2]]
+                conn <- socketConnection(
+                  port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+                output <- tryCatch({
+                  doServerAuth(conn, authSecret)
+                  arrowTable <- read_arrow(readRaw(conn))
+                  as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors)
+                }, finally = {
+                  close(conn)
+                })
+                return(output)
+              } else {
+                stop("'arrow' package should be installed.")
+              }
             } else {
               # listCols is a list of columns
               listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
```
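The `collectAsArrowToR` hunk further below writes the whole Arrow stream as one length-prefixed blob (`out.writeInt(batches.length)` followed by `out.write(batches)`), and the R side above consumes it with `read_arrow(readRaw(conn))`. As a rough, standalone sketch of that framing (the helper name below is made up; `readRaw` is SparkR's internal equivalent), the read side boils down to:

```r
# Read one length-prefixed binary payload from an open binary connection:
# a 4-byte big-endian length, then that many bytes (assumed framing, mirroring
# the JVM side's writeInt/write in the Dataset.scala hunk below).
read_length_prefixed <- function(conn) {
  data_len <- readBin(conn, what = integer(), n = 1L, size = 4L, endian = "big")
  readBin(conn, what = raw(), n = data_len)
}

# The resulting raw vector is then parsed in one go, e.g. with
# arrow::read_arrow(read_length_prefixed(conn)), as the new collect() path does.
```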
```diff
@@ -307,7 +307,7 @@ test_that("create DataFrame from RDD", {
   unsetHiveContext()
 })
 
-test_that("createDataFrame Arrow optimization", {
+test_that("createDataFrame/collect Arrow optimization", {
   skip_if_not_installed("arrow")
 
   conf <- callJMethod(sparkSession, "conf")
@@ -332,7 +332,24 @@ test_that("createDataFrame Arrow optimization", {
   })
 })
 
-test_that("createDataFrame Arrow optimization - type specification", {
+test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", {
+  skip_if_not_installed("arrow")
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)),
+                 collect(createDataFrame(mtcars, numPartitions = 1)))
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
+test_that("createDataFrame/collect Arrow optimization - type specification", {
   skip_if_not_installed("arrow")
   rdf <- data.frame(list(list(a = 1,
                               b = "a",
```
```diff
@@ -430,6 +430,12 @@ private[spark] object PythonRDD extends Logging {
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+    serveToStream(threadName, authHelper)(writeFunc)
+  }
+
+  private[spark] def serveToStream(
+      threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
+    : Array[Any] = {
     val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
       val out = new BufferedOutputStream(s.getOutputStream())
       Utils.tryWithSafeFinally {
```
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataInputStream, File}
+import java.io.{DataInputStream, File, OutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Map => JMap}
@@ -104,7 +104,7 @@ private class StringRRDD[T: ClassTag](
   lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
 }
 
-private[r] object RRDD {
+private[spark] object RRDD {
   def createSparkContext(
       master: String,
       appName: String,
@@ -165,6 +165,11 @@ private[r] object RRDD {
       JavaRDD[Array[Byte]] = {
     PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
   }
+
+  private[spark] def serveToStream(
+      threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+    PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
+  }
 }
 
 /**
```
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.io.{CharArrayWriter, DataOutputStream}
+import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
 import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
@@ -3198,9 +3199,66 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Collect a Dataset as Arrow batches and serve stream to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to SparkR. It sends
+   * arrow batches in an ordered manner with buffering. This is inevitable
+   * due to missing R API that reads batches from socket directly. See ARROW-4512.
+   * Eventually, this code should be deduplicated by `collectAsArrowToPython`.
    */
-  private[sql] def collectAsArrowToPython(): Array[Any] = {
+  private[sql] def collectAsArrowToR(): Array[Any] = {
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
+    withAction("collectAsArrowToR", queryExecution) { plan =>
+      RRDD.serveToStream("serve-Arrow") { outputStream =>
+        val buffer = new ByteArrayOutputStream()
+        val out = new DataOutputStream(outputStream)
+        val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId)
+        val arrowBatchRdd = toArrowBatchRdd(plan)
+        val numPartitions = arrowBatchRdd.partitions.length
+
+        // Store collection results for worst case of 1 to N-1 partitions
+        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
+        var lastIndex = -1 // index of last partition written
+
+        // Handler to eagerly write partitions to Python in order
+        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+          // If result is from next partition in order
+          if (index - 1 == lastIndex) {
+            batchWriter.writeBatches(arrowBatches.iterator)
+            lastIndex += 1
+            // Write stored partitions that come next in order
+            while (lastIndex < results.length && results(lastIndex) != null) {
+              batchWriter.writeBatches(results(lastIndex).iterator)
+              results(lastIndex) = null
+              lastIndex += 1
+            }
+            // After last batch, end the stream
+            if (lastIndex == results.length) {
+              batchWriter.end()
+              val batches = buffer.toByteArray
+              out.writeInt(batches.length)
+              out.write(batches)
+            }
+          } else {
+            // Store partitions received out of order
+            results(index - 1) = arrowBatches
+          }
+        }
+
+        sparkSession.sparkContext.runJob(
+          arrowBatchRdd,
+          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
+          0 until numPartitions,
+          handlePartitionBatches)
+      }
+    }
+  }
+
+  /**
+   * Collect a Dataset as Arrow batches and serve stream to PySpark. It sends
+   * arrow batches in an un-ordered manner without buffering, and then batch order
+   * information at the end. The batches should be reordered at Python side.
+   */
+  private[sql] def collectAsArrowToPython: Array[Any] = {
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
     withAction("collectAsArrowToPython", queryExecution) { plan =>
@@ -3211,7 +3269,7 @@ class Dataset[T] private[sql](
         val numPartitions = arrowBatchRdd.partitions.length
 
         // Batches ordered by (index of partition, batch index in that partition) tuple
-        val batchOrder = new ArrayBuffer[(Int, Int)]()
+        val batchOrder = ArrayBuffer.empty[(Int, Int)]
         var partitionCount = 0
 
         // Handler to eagerly write batches to Python as they arrive, un-ordered
@@ -3220,7 +3278,7 @@ class Dataset[T] private[sql](
             // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
             arrowBatches.indices.foreach {
-              partition_batch_index => batchOrder.append((index, partition_batch_index))
+              partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
            }
          }
          partitionCount += 1
@@ -3232,8 +3290,8 @@ class Dataset[T] private[sql](
         // Sort by (index of partition, batch index in that partition) tuple to get the
         // overall_batch_index from 0 to N-1 batches, which can be used to put the
         // transferred batches in the correct order
-        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
-          out.writeInt(overall_batch_index)
+        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+          out.writeInt(overallBatchIndex)
         }
         out.flush()
       }
```
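The ordered-with-buffering handler in `collectAsArrowToR` above is the crux of the change: partition results can reach the driver in any order, but the Arrow batches must be written to the stream in partition order. Below is a small, self-contained R toy (not Spark code; names are made up, and it uses 1-based indices unlike the 0-based Scala handler) that mimics the same scheme.

```r
# Toy version of the buffering scheme in handlePartitionBatches: emit results in
# partition order, buffering anything that arrives early.
make_ordered_writer <- function(num_partitions, write_fn) {
  pending <- vector("list", num_partitions)  # buffer for out-of-order arrivals
  next_to_write <- 1L
  function(index, payload) {
    pending[[index]] <<- payload
    # Flush every buffered partition that is now next in line.
    while (next_to_write <= num_partitions && !is.null(pending[[next_to_write]])) {
      write_fn(pending[[next_to_write]])
      next_to_write <<- next_to_write + 1L
    }
  }
}

writer <- make_ordered_writer(3, function(x) cat("writing", x, "\n"))
writer(2, "partition 2")  # arrives early, buffered
writer(1, "partition 1")  # writes partition 1, then the buffered partition 2
writer(3, "partition 3")  # writes partition 3
```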