From 3c15d8b71c1515dd7f599a902d888fa0c33d40f0 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Wed, 20 Feb 2019 11:35:17 +0800
Subject: [PATCH] [SPARK-26762][SQL][R] Arrow optimization for conversion from Spark DataFrame to R DataFrame

## What changes were proposed in this pull request?

This PR targets to support Arrow optimization for the conversion from Spark DataFrame to R DataFrame.
Like the PySpark side, it falls back to the non-optimized code path when it is unable to use Arrow optimization.

This can be tested as below:

```bash
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
collect(createDataFrame(mtcars))
```

### Requirements
  - R 3.5.x
  - Arrow package 0.12+

```bash
Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.0", subdir = "r")'
```

**Note:** currently, the Arrow R package is not on CRAN. Please take a look at ARROW-3204.
**Note:** currently, the Arrow R package does not seem to support Windows. Please take a look at ARROW-3204.

### Benchmarks

**Shell**

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=false --driver-memory 4g
```

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=true --driver-memory 4g
```

**R code**

```r
df <- cache(createDataFrame(read.csv("500000.csv")))
count(df)

test <- function() {
  options(digits.secs = 6) # milliseconds
  start.time <- Sys.time()
  collect(df)
  end.time <- Sys.time()
  time.taken <- end.time - start.time
  print(time.taken)
}

test()
```

**Data (350 MB):**

```r
object.size(read.csv("500000.csv"))
350379504 bytes
```

"500000 Records" from http://eforexcel.com/wp/downloads-16-sample-csv-files-data-sets-for-testing/

**Results**

Arrow optimization disabled:

```
Time difference of 221.32014 secs
```

Arrow optimization enabled:

```
Time difference of 15.51145 secs
```

The performance improvement was around **1426%**.

### Limitations:

- For now, Arrow optimization with R does not support the case where the data contains `raw` values or the user explicitly specifies a float type in the schema, because these produce corrupt values. In those cases, we fall back to the non-optimized code path.
- Due to ARROW-4512, it cannot send and receive batch by batch; it has to send all batches in the Arrow stream format at once. This needs improvement later.

## How was this patch tested?

Existing tests related to Arrow optimization cover this change. Also, manually tested.

Closes #23760 from HyukjinKwon/SPARK-26762.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 R/pkg/R/DataFrame.R                          | 56 +++++++++++++++
 R/pkg/tests/fulltests/test_sparkSQL.R        | 21 +++++-
 .../apache/spark/api/python/PythonRDD.scala  |  6 ++
 .../scala/org/apache/spark/api/r/RRDD.scala  |  9 ++-
 .../scala/org/apache/spark/sql/Dataset.scala | 72 +++++++++++++++++--
 5 files changed, 153 insertions(+), 11 deletions(-)
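As a quick reference (illustrative only, not part of the patch), a minimal sketch of how the new Arrow-backed `collect()` path is exercised from SparkR. It assumes a local SparkR session and that the `arrow` R package (0.12+) is installed; if `arrow` is missing or the conversion fails, `collect()` is expected to warn and fall back to the non-optimized path.

```r
# Minimal sketch: enable Arrow optimization and collect a Spark DataFrame back to R.
library(SparkR)
sparkR.session(sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))

df <- createDataFrame(mtcars)
rdf <- collect(df)  # served through the new collectAsArrowToR() when Arrow is usable
stopifnot(nrow(rdf) == nrow(mtcars))

sparkR.session.stop()
```
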
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 24ed449f2a..fe836bf4c1 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1177,11 +1177,67 @@ setMethod("dim",
 setMethod("collect",
           signature(x = "SparkDataFrame"),
           function(x, stringsAsFactors = FALSE) {
+            connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
+            useArrow <- FALSE
+            arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
+            if (arrowEnabled) {
+              useArrow <- tryCatch({
+                requireNamespace1 <- requireNamespace
+                if (!requireNamespace1("arrow", quietly = TRUE)) {
+                  stop("'arrow' package should be installed.")
+                }
+                # Currently Arrow optimization does not support raw for now.
+                # Also, it does not support explicit float type set by users.
+                if (inherits(schema(x), "structType")) {
+                  if (any(sapply(schema(x)$fields(),
+                                 function(x) x$dataType.toString() == "FloatType"))) {
+                    stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
+                                "DataFrame does not support FloatType yet."))
+                  }
+                  if (any(sapply(schema(x)$fields(),
+                                 function(x) x$dataType.toString() == "BinaryType"))) {
+                    stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
+                                "DataFrame does not support BinaryType yet."))
+                  }
+                }
+                TRUE
+              }, error = function(e) {
+                warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ",
+                               "with Arrow optimization because ",
+                               "'spark.sql.execution.arrow.enabled' is set to true; however, ",
+                               "failed, attempting non-optimization. Reason: ",
+                               e))
+                FALSE
+              })
+            }
+
             dtypes <- dtypes(x)
             ncol <- length(dtypes)
             if (ncol <= 0) {
               # empty data.frame with 0 columns and 0 rows
               data.frame()
+            } else if (useArrow) {
+              requireNamespace1 <- requireNamespace
+              if (requireNamespace1("arrow", quietly = TRUE)) {
+                read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE)
+                as_tibble <- get("as_tibble", envir = asNamespace("arrow"))
+
+                portAuth <- callJMethod(x@sdf, "collectAsArrowToR")
+                port <- portAuth[[1]]
+                authSecret <- portAuth[[2]]
+                conn <- socketConnection(
+                  port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+                output <- tryCatch({
+                  doServerAuth(conn, authSecret)
+                  arrowTable <- read_arrow(readRaw(conn))
+                  as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors)
+                }, finally = {
+                  close(conn)
+                })
+                return(output)
+              } else {
+                stop("'arrow' package should be installed.")
+              }
             } else {
               # listCols is a list of columns
               listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 4d1360b2a0..19c41ef449 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -307,7 +307,7 @@ test_that("create DataFrame from RDD", {
   unsetHiveContext()
 })
 
-test_that("createDataFrame Arrow optimization", {
+test_that("createDataFrame/collect Arrow optimization", {
   skip_if_not_installed("arrow")
 
   conf <- callJMethod(sparkSession, "conf")
@@ -332,7 +332,24 @@ test_that("createDataFrame Arrow optimization", {
   })
 })
 
-test_that("createDataFrame Arrow optimization - type specification", {
+test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", {
+  skip_if_not_installed("arrow")
+
+  conf <- callJMethod(sparkSession, "conf")
+  arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]
+
+  callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
+  tryCatch({
+    expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)),
+                 collect(createDataFrame(mtcars, numPartitions = 1)))
+  },
+  finally = {
+    # Resetting the conf back to default value
+    callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
+  })
+})
+
+test_that("createDataFrame/collect Arrow optimization - type specification", {
   skip_if_not_installed("arrow")
 
   rdf <- data.frame(list(list(a = 1, b = "a",
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 14ea289e5f..0937a63dad 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -430,6 +430,12 @@ private[spark] object PythonRDD extends Logging {
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+    serveToStream(threadName, authHelper)(writeFunc)
+  }
+
+  private[spark] def serveToStream(
+      threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
+    : Array[Any] = {
     val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
       val out = new BufferedOutputStream(s.getOutputStream())
       Utils.tryWithSafeFinally {
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 1dc61c7eef..04fc6e18c1 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataInputStream, File}
+import java.io.{DataInputStream, File, OutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Map => JMap}
@@ -104,7 +104,7 @@ private class StringRRDD[T: ClassTag](
   lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
 }
 
-private[r] object RRDD {
+private[spark] object RRDD {
   def createSparkContext(
       master: String,
       appName: String,
@@ -165,6 +165,11 @@ private[r] object RRDD {
     JavaRDD[Array[Byte]] = {
     PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
   }
+
+  private[spark] def serveToStream(
+      threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
+    PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 8a26152271..bd1ae509cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.io.{CharArrayWriter, DataOutputStream}
+import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
 import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
@@ -3198,9 +3199,66 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Collect a Dataset as Arrow batches and serve stream to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to SparkR. It sends
+   * arrow batches in an ordered manner with buffering. This is inevitable
+   * due to missing R API that reads batches from socket directly. See ARROW-4512.
+   * Eventually, this code should be deduplicated by `collectAsArrowToPython`.
    */
-  private[sql] def collectAsArrowToPython(): Array[Any] = {
+  private[sql] def collectAsArrowToR(): Array[Any] = {
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
+    withAction("collectAsArrowToR", queryExecution) { plan =>
+      RRDD.serveToStream("serve-Arrow") { outputStream =>
+        val buffer = new ByteArrayOutputStream()
+        val out = new DataOutputStream(outputStream)
+        val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId)
+        val arrowBatchRdd = toArrowBatchRdd(plan)
+        val numPartitions = arrowBatchRdd.partitions.length
+
+        // Store collection results for worst case of 1 to N-1 partitions
+        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
+        var lastIndex = -1  // index of last partition written
+
+        // Handler to eagerly write partitions in order
+        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+          // If result is from next partition in order
+          if (index - 1 == lastIndex) {
+            batchWriter.writeBatches(arrowBatches.iterator)
+            lastIndex += 1
+            // Write stored partitions that come next in order
+            while (lastIndex < results.length && results(lastIndex) != null) {
+              batchWriter.writeBatches(results(lastIndex).iterator)
+              results(lastIndex) = null
+              lastIndex += 1
+            }
+            // After last batch, end the stream
+            if (lastIndex == results.length) {
+              batchWriter.end()
+              val batches = buffer.toByteArray
+              out.writeInt(batches.length)
+              out.write(batches)
+            }
+          } else {
+            // Store partitions received out of order
+            results(index - 1) = arrowBatches
+          }
+        }
+
+        sparkSession.sparkContext.runJob(
+          arrowBatchRdd,
+          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
+          0 until numPartitions,
+          handlePartitionBatches)
+      }
+    }
+  }
+
+  /**
+   * Collect a Dataset as Arrow batches and serve stream to PySpark. It sends
+   * arrow batches in an un-ordered manner without buffering, and then batch order
+   * information at the end. The batches should be reordered at Python side.
+   */
+  private[sql] def collectAsArrowToPython(): Array[Any] = {
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
@@ -3211,7 +3269,7 @@ class Dataset[T] private[sql](
       val numPartitions = arrowBatchRdd.partitions.length
 
       // Batches ordered by (index of partition, batch index in that partition) tuple
-      val batchOrder = new ArrayBuffer[(Int, Int)]()
+      val batchOrder = ArrayBuffer.empty[(Int, Int)]
       var partitionCount = 0
 
       // Handler to eagerly write batches to Python as they arrive, un-ordered
@@ -3220,7 +3278,7 @@ class Dataset[T] private[sql](
           // Write all batches (can be more than 1) in the partition, store the batch order tuple
           batchWriter.writeBatches(arrowBatches.iterator)
           arrowBatches.indices.foreach {
-            partition_batch_index => batchOrder.append((index, partition_batch_index))
+            partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
          }
         }
         partitionCount += 1
@@ -3232,8 +3290,8 @@ class Dataset[T] private[sql](
         // Sort by (index of partition, batch index in that partition) tuple to get the
         // overall_batch_index from 0 to N-1 batches, which can be used to put the
         // transferred batches in the correct order
-        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
-          out.writeInt(overall_batch_index)
+        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+          out.writeInt(overallBatchIndex)
         }
         out.flush()
       }
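
Also for reference, a hypothetical sketch of how the FloatType limitation called out above is expected to surface to users; the schema and data here are made up for illustration, and it assumes `spark.sql.execution.arrow.enabled=true` with the `arrow` package installed.

```r
# Hypothetical repro of the documented limitation: with an explicit float field,
# collect() should warn and fall back to the non-Arrow path instead of returning
# corrupt values.
schema <- structType(structField("x", "float"), structField("y", "string"))
df <- createDataFrame(data.frame(x = c(1.5, 2.5), y = c("a", "b")), schema = schema)
head(collect(df))  # expect a warning about falling back to non-optimization
```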