[SPARK-26762][SQL][R] Arrow optimization for conversion from Spark DataFrame to R DataFrame

## What changes were proposed in this pull request?

This PR adds Arrow optimization support for the conversion from a Spark DataFrame to an R DataFrame.
As on the PySpark side, it falls back to the non-optimized code path when Arrow optimization cannot be used.

This can be tested as follows:

```bash
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
collect(createDataFrame(mtcars))
```
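
Inside the SparkR shell, you can double-check that the flag was picked up (`sparkR.conf` is an existing SparkR API, also used internally by this change):

```r
# Should return "true" when the shell was started with the conf above
sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
```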

### Requirements
  - R 3.5.x
  - Arrow package 0.12+
    ```bash
    Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.0", subdir = "r")'
    ```
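
To quickly verify the installation (a sanity check, not part of this PR):

```r
# Should report 0.12.0 or newer once the package above is installed
packageVersion("arrow")
```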

**Note:** currently, the Arrow R package is not on CRAN. Please see ARROW-3204.
**Note:** currently, the Arrow R package does not seem to support Windows. Please see ARROW-3204.

### Benchmarks

**Shell**

Without Arrow optimization:

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=false --driver-memory 4g
```

With Arrow optimization:

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=true --driver-memory 4g
```

**R code**

```r
df <- cache(createDataFrame(read.csv("500000.csv")))
count(df)

test <- function() {
  options(digits.secs = 6) # print sub-second precision
  start.time <- Sys.time()
  collect(df)
  end.time <- Sys.time()
  time.taken <- end.time - start.time
  print(time.taken)
}

test()
```

**Data (350 MB):**

```r
object.size(read.csv("500000.csv"))
350379504 bytes
```

"500000 Records"  http://eforexcel.com/wp/downloads-16-sample-csv-files-data-sets-for-testing/

**Results**

Without Arrow optimization:

```
Time difference of 221.32014 secs
```

With Arrow optimization:

```
Time difference of 15.51145 secs
```

The performance improvement was around **1426%**, i.e. roughly a 14x speedup.
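
For reference, the figure comes from the ratio of the two timings above:

```r
221.32014 / 15.51145  # ~14.27, i.e. around 1426%
```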

### Limitations

- For now, Arrow optimization with R does not support data of `raw` type, or a schema in which the user explicitly specifies float type; Arrow would produce corrupt values for them, so those cases fall back to the non-optimization code path (see the sketch after this list).

- Due to ARROW-4512, batches cannot be sent and received one by one; all batches have to be sent at once in the Arrow stream format. This needs improvement later.
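
For illustration, a minimal sketch of the float fallback (the column name and values are made up; it assumes a running SparkR session with `spark.sql.execution.arrow.enabled=true` and the `arrow` package installed):

```r
# An explicitly float-typed schema is not supported by the Arrow path yet ...
schema <- structType(structField("x", "float"))
df <- createDataFrame(data.frame(x = c(1.5, 2.5)), schema = schema)

# ... so collect() emits a warning and transparently falls back to the
# non-Arrow code path instead of producing corrupt values.
collect(df)
```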

## How was this patch tested?

Existing tests related to Arrow optimization cover this change. It was also tested manually.

Closes #23760 from HyukjinKwon/SPARK-26762.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>

R/pkg/R/DataFrame.R

@@ -1177,11 +1177,67 @@ setMethod("dim",
setMethod("collect",
signature(x = "SparkDataFrame"),
function(x, stringsAsFactors = FALSE) {
connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
useArrow <- FALSE
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
if (arrowEnabled) {
useArrow <- tryCatch({
requireNamespace1 <- requireNamespace
if (!requireNamespace1("arrow", quietly = TRUE)) {
stop("'arrow' package should be installed.")
}
# Currently, Arrow optimization does not support the raw type.
# Also, it does not support explicit float type set by users.
if (inherits(schema(x), "structType")) {
if (any(sapply(schema(x)$fields(),
function(x) x$dataType.toString() == "FloatType"))) {
stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
"DataFrame does not support FloatType yet."))
}
if (any(sapply(schema(x)$fields(),
function(x) x$dataType.toString() == "BinaryType"))) {
stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
"DataFrame does not support BinaryType yet."))
}
}
TRUE
}, error = function(e) {
warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ",
"with Arrow optimization because ",
"'spark.sql.execution.arrow.enabled' is set to true; however, ",
"failed, attempting non-optimization. Reason: ",
e))
FALSE
})
}
dtypes <- dtypes(x)
ncol <- length(dtypes)
if (ncol <= 0) {
# empty data.frame with 0 columns and 0 rows
data.frame()
} else if (useArrow) {
requireNamespace1 <- requireNamespace
if (requireNamespace1("arrow", quietly = TRUE)) {
read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE)
as_tibble <- get("as_tibble", envir = asNamespace("arrow"))
portAuth <- callJMethod(x@sdf, "collectAsArrowToR")
port <- portAuth[[1]]
authSecret <- portAuth[[2]]
conn <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
output <- tryCatch({
doServerAuth(conn, authSecret)
arrowTable <- read_arrow(readRaw(conn))
as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors)
}, finally = {
close(conn)
})
return(output)
} else {
stop("'arrow' package should be installed.")
}
} else {
# listCols is a list of columns
listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)

R/pkg/tests/fulltests/test_sparkSQL.R

@@ -307,7 +307,7 @@ test_that("create DataFrame from RDD", {
unsetHiveContext()
})
test_that("createDataFrame Arrow optimization", {
test_that("createDataFrame/collect Arrow optimization", {
skip_if_not_installed("arrow")
conf <- callJMethod(sparkSession, "conf")
@@ -332,7 +332,24 @@ test_that("createDataFrame Arrow optimization", {
})
})
test_that("createDataFrame Arrow optimization - type specification", {
test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", {
skip_if_not_installed("arrow")
conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)),
collect(createDataFrame(mtcars, numPartitions = 1)))
},
finally = {
# Resetting the conf back to default value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})
test_that("createDataFrame/collect Arrow optimization - type specification", {
skip_if_not_installed("arrow")
rdf <- data.frame(list(list(a = 1,
b = "a",

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

@@ -430,6 +430,12 @@ private[spark] object PythonRDD extends Logging {
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
serveToStream(threadName, authHelper)(writeFunc)
}
private[spark] def serveToStream(
threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
: Array[Any] = {
val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
val out = new BufferedOutputStream(s.getOutputStream())
Utils.tryWithSafeFinally {

core/src/main/scala/org/apache/spark/api/r/RRDD.scala

@@ -17,7 +17,7 @@
package org.apache.spark.api.r
import java.io.{DataInputStream, File}
import java.io.{DataInputStream, File, OutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Map => JMap}
@@ -104,7 +104,7 @@ private class StringRRDD[T: ClassTag](
lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
}
private[r] object RRDD {
private[spark] object RRDD {
def createSparkContext(
master: String,
appName: String,
@@ -165,6 +165,11 @@ private[r] object RRDD {
JavaRDD[Array[Byte]] = {
PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
}
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
}
}
/**

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

@@ -17,7 +17,7 @@
package org.apache.spark.sql
import java.io.{CharArrayWriter, DataOutputStream}
import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.QueryPlanningTracker
@@ -3198,9 +3199,66 @@ class Dataset[T] private[sql](
}
/**
* Collect a Dataset as Arrow batches and serve stream to PySpark.
* Collect a Dataset as Arrow batches and serve stream to SparkR. It sends
* arrow batches in an ordered manner with buffering. This is inevitable
* due to missing R API that reads batches from socket directly. See ARROW-4512.
* Eventually, this code should be deduplicated by `collectAsArrowToPython`.
*/
private[sql] def collectAsArrowToPython(): Array[Any] = {
private[sql] def collectAsArrowToR(): Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
withAction("collectAsArrowToR", queryExecution) { plan =>
RRDD.serveToStream("serve-Arrow") { outputStream =>
val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(outputStream)
val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length
// Store collection results for worst case of 1 to N-1 partitions
val results = new Array[Array[Array[Byte]]](numPartitions - 1)
var lastIndex = -1 // index of last partition written
// Handler to eagerly write partitions to R in order
def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
// If result is from next partition in order
if (index - 1 == lastIndex) {
batchWriter.writeBatches(arrowBatches.iterator)
lastIndex += 1
// Write stored partitions that come next in order
while (lastIndex < results.length && results(lastIndex) != null) {
batchWriter.writeBatches(results(lastIndex).iterator)
results(lastIndex) = null
lastIndex += 1
}
// After last batch, end the stream
if (lastIndex == results.length) {
batchWriter.end()
val batches = buffer.toByteArray
out.writeInt(batches.length)
out.write(batches)
}
} else {
// Store partitions received out of order
results(index - 1) = arrowBatches
}
}
sparkSession.sparkContext.runJob(
arrowBatchRdd,
(ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
0 until numPartitions,
handlePartitionBatches)
}
}
}
/**
* Collect a Dataset as Arrow batches and serve stream to PySpark. It sends
* arrow batches in an un-ordered manner without buffering, and then batch order
* information at the end. The batches should be reordered at Python side.
*/
private[sql] def collectAsArrowToPython: Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
withAction("collectAsArrowToPython", queryExecution) { plan =>
@@ -3211,7 +3269,7 @@ class Dataset[T] private[sql](
val numPartitions = arrowBatchRdd.partitions.length
// Batches ordered by (index of partition, batch index in that partition) tuple
val batchOrder = new ArrayBuffer[(Int, Int)]()
val batchOrder = ArrayBuffer.empty[(Int, Int)]
var partitionCount = 0
// Handler to eagerly write batches to Python as they arrive, un-ordered
@@ -3220,7 +3278,7 @@ class Dataset[T] private[sql](
// Write all batches (can be more than 1) in the partition, store the batch order tuple
batchWriter.writeBatches(arrowBatches.iterator)
arrowBatches.indices.foreach {
partition_batch_index => batchOrder.append((index, partition_batch_index))
partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
}
}
partitionCount += 1
@@ -3232,8 +3290,8 @@ class Dataset[T] private[sql](
// Sort by (index of partition, batch index in that partition) tuple to get the
// overall_batch_index from 0 to N-1 batches, which can be used to put the
// transferred batches in the correct order
batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
out.writeInt(overall_batch_index)
batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
out.writeInt(overallBatchIndex)
}
out.flush()
}