[SPARK-12919][SPARKR] Implement dapply() on DataFrame in SparkR.
## What changes were proposed in this pull request?

dapply() applies an R function to each partition of a DataFrame and returns a new DataFrame. The function signature is:

    dapply(df, function(localDF) {}, schema = NULL)

R function input: a local data.frame corresponding to the partition on the local node.
R function output: a local data.frame.

`schema` specifies the row format of the resulting DataFrame. It must match the output of the R function. If `schema` is not specified, each partition of the result DataFrame is serialized in R into a single byte array; such a DataFrame can then be processed by subsequent dapply() calls.

## How was this patch tested?

SparkR unit tests.

Author: Sun Rui <rui.sun@intel.com>
Author: Sun Rui <sunrui2016@gmail.com>

Closes #12493 from sun-rui/SPARK-12919.
parent d78fbcc3cc
commit 4ae9fe091c
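For illustration only (not part of the patch): a minimal sketch of the dapply() contract described above, assuming a working SparkR session where `sqlContext` has been initialized. It mirrors the filter-and-add-a-column example used in the Roxygen docs and tests below.

```r
# Sketch only: assumes SparkR is attached and `sqlContext` exists.
df <- createDataFrame(sqlContext,
                      list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
                      c("a", "b", "c"))

# The schema must describe exactly what the function returns:
# the original three columns plus a new integer column "d".
schema <- structType(structField("a", "integer"), structField("b", "double"),
                     structField("c", "string"), structField("d", "integer"))

df1 <- dapply(df, function(x) {
  # x is a local data.frame holding one partition
  y <- x[x$a > 1, ]        # filter rows
  cbind(y, d = y$a + 1L)   # add a column; the result is again a local data.frame
}, schema)

collect(df1)
#   a b c d
# 1 2 2 2 3
# 2 3 3 3 4
```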
@@ -45,6 +45,7 @@ exportMethods("arrange",
               "covar_samp",
               "covar_pop",
               "crosstab",
+              "dapply",
               "describe",
               "dim",
               "distinct",
@@ -21,6 +21,7 @@
 NULL
 
 setOldClass("jobj")
+setOldClass("structType")
 
 #' @title S4 class that represents a SparkDataFrame
 #' @description DataFrames can be created using functions like \link{createDataFrame},
@@ -1125,6 +1126,66 @@ setMethod("summarize",
             agg(x, ...)
           })
 
+#' dapply
+#'
+#' Apply a function to each partition of a DataFrame.
+#'
+#' @param x A SparkDataFrame
+#' @param func A function to be applied to each partition of the SparkDataFrame.
+#'             func should have only one parameter, to which a data.frame corresponding
+#'             to one partition will be passed.
+#'             The output of func should be a data.frame.
+#' @param schema The schema of the resulting DataFrame after the function is applied.
+#'               It must match the output of func.
+#' @family SparkDataFrame functions
+#' @rdname dapply
+#' @name dapply
+#' @export
+#' @examples
+#' \dontrun{
+#'   df <- createDataFrame(sqlContext, iris)
+#'   df1 <- dapply(df, function(x) { x }, schema(df))
+#'   collect(df1)
+#'
+#'   # filter and add a column
+#'   df <- createDataFrame(
+#'           sqlContext,
+#'           list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+#'           c("a", "b", "c"))
+#'   schema <- structType(structField("a", "integer"), structField("b", "double"),
+#'                        structField("c", "string"), structField("d", "integer"))
+#'   df1 <- dapply(
+#'            df,
+#'            function(x) {
+#'              y <- x[x[1] > 1, ]
+#'              y <- cbind(y, y[1] + 1L)
+#'            },
+#'            schema)
+#'   collect(df1)
+#'   # the result
+#'   #   a b c d
+#'   # 1 2 2 2 3
+#'   # 2 3 3 3 4
+#' }
+setMethod("dapply",
+          signature(x = "SparkDataFrame", func = "function", schema = "structType"),
+          function(x, func, schema) {
+            packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+                                         connection = NULL)
+
+            broadcastArr <- lapply(ls(.broadcastNames),
+                                   function(name) { get(name, .broadcastNames) })
+
+            sdf <- callJStatic(
+                     "org.apache.spark.sql.api.r.SQLUtils",
+                     "dapply",
+                     x@sdf,
+                     serialize(cleanClosure(func), connection = NULL),
+                     packageNamesArr,
+                     broadcastArr,
+                     schema$jobj)
+            dataFrame(sdf)
+          })
+
 ############################## RDD Map Functions ##################################
 # All of the following functions mirror the existing RDD map functions,           #
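As background for the `serialize(cleanClosure(func), connection = NULL)` call above (plain base R, not part of the patch): the user function travels to the JVM, and eventually to an R worker, as a raw byte vector produced by `serialize()`, and is rebuilt with `unserialize()` on the other side.

```r
# Base R round-trip of a closure, the mechanism the dapply() method builds on.
f <- function(x) { x * 2L }

bytes <- serialize(f, connection = NULL)   # raw vector, analogous to what dapply() ships
g <- unserialize(bytes)                    # what an R worker would reconstruct

stopifnot(identical(g(21L), 42L))
```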
@@ -446,6 +446,10 @@ setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") })
 #' @export
 setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") })
 
+#' @rdname dapply
+#' @export
+setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
+
 #' @rdname summary
 #' @export
 setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
@@ -2043,6 +2043,46 @@ test_that("Histogram", {
   df <- as.DataFrame(sqlContext, data.frame(x = c(1, 2, 3, 4, 100)))
   expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
 })
 
+test_that("dapply() on a DataFrame", {
+  df <- createDataFrame(
+          sqlContext,
+          list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+          c("a", "b", "c"))
+  ldf <- collect(df)
+  df1 <- dapply(df, function(x) { x }, schema(df))
+  result <- collect(df1)
+  expect_identical(ldf, result)
+
+  # Filter and add a column
+  schema <- structType(structField("a", "integer"), structField("b", "double"),
+                       structField("c", "string"), structField("d", "integer"))
+  df1 <- dapply(
+           df,
+           function(x) {
+             y <- x[x$a > 1, ]
+             y <- cbind(y, y$a + 1L)
+           },
+           schema)
+  result <- collect(df1)
+  expected <- ldf[ldf$a > 1, ]
+  expected$d <- expected$a + 1L
+  rownames(expected) <- NULL
+  expect_identical(expected, result)
+
+  # Remove the added column
+  df2 <- dapply(
+           df1,
+           function(x) {
+             x[, c("a", "b", "c")]
+           },
+           schema(df))
+  result <- collect(df2)
+  expected <- expected[, c("a", "b", "c")]
+  expect_identical(expected, result)
+})
+
 unlink(parquetPath)
 unlink(jsonPath)
 unlink(jsonPathNa)
@@ -84,6 +84,13 @@ broadcastElap <- elapsedSecs()
 # as number of partitions to create.
 numPartitions <- SparkR:::readInt(inputCon)
 
+isDataFrame <- as.logical(SparkR:::readInt(inputCon))
+
+# If isDataFrame, then read column names
+if (isDataFrame) {
+  colNames <- SparkR:::readObject(inputCon)
+}
+
 isEmpty <- SparkR:::readInt(inputCon)
 
 if (isEmpty != 0) {
@@ -100,7 +107,34 @@ if (isEmpty != 0) {
   # Timing reading input data for execution
   inputElap <- elapsedSecs()
 
-  output <- computeFunc(partition, data)
+  if (isDataFrame) {
+    if (deserializer == "row") {
+      # Transform the list of rows into a data.frame
+      # Note that the optional argument stringsAsFactors for rbind is
+      # available since R 3.2.4. So we set the global option here.
+      oldOpt <- getOption("stringsAsFactors")
+      options(stringsAsFactors = FALSE)
+      data <- do.call(rbind.data.frame, data)
+      options(stringsAsFactors = oldOpt)
+
+      names(data) <- colNames
+    } else {
+      # Check to see if data is a valid data.frame
+      stopifnot(deserializer == "byte")
+      stopifnot(class(data) == "data.frame")
+    }
+    output <- computeFunc(data)
+    if (serializer == "row") {
+      # Transform the result data.frame back to a list of rows
+      output <- split(output, seq(nrow(output)))
+    } else {
+      # Serialize the output to a byte array
+      stopifnot(serializer == "byte")
+    }
+  } else {
+    output <- computeFunc(partition, data)
+  }
 
   # Timing computing
   computeElap <- elapsedSecs()
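The worker changes above revolve around converting a list of deserialized rows into a local data.frame before calling the user function, and splitting the result back into rows afterwards. A standalone sketch of that round-trip (hypothetical sample rows; column names are assumed to arrive separately, as in the "row" format):

```r
# Each SQL row reaches the worker as an R list; column names are read separately.
rows <- list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3"))
colNames <- c("a", "b", "c")

# Bind the rows into a data.frame without turning strings into factors.
# (worker.R toggles the global option because rbind's stringsAsFactors
# argument only exists from R 3.2.4 onwards.)
oldOpt <- getOption("stringsAsFactors")
options(stringsAsFactors = FALSE)
localDF <- do.call(rbind.data.frame, rows)
options(stringsAsFactors = oldOpt)
names(localDF) <- colNames

# After the user function runs, the resulting data.frame goes back out row by row.
outRows <- split(localDF, seq(nrow(localDF)))
length(outRows)  # one element per row
```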
@@ -46,7 +46,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     // The parent may be also an RRDD, so we should launch it first.
     val parentIterator = firstParent[T].iterator(partition, context)
 
-    runner.compute(parentIterator, partition.index, context)
+    runner.compute(parentIterator, partition.index)
   }
 }
 
@@ -38,7 +38,9 @@ private[spark] class RRunner[U](
     serializer: String,
     packageNames: Array[Byte],
     broadcastVars: Array[Broadcast[Object]],
-    numPartitions: Int = -1)
+    numPartitions: Int = -1,
+    isDataFrame: Boolean = false,
+    colNames: Array[String] = null)
   extends Logging {
   private var bootTime: Double = _
   private var dataStream: DataInputStream = _
@@ -53,8 +55,7 @@ private[spark] class RRunner[U](
 
   def compute(
       inputIterator: Iterator[_],
-      partitionIndex: Int,
-      context: TaskContext): Iterator[U] = {
+      partitionIndex: Int): Iterator[U] = {
     // Timing start
     bootTime = System.currentTimeMillis / 1000.0
 
@@ -148,6 +149,12 @@ private[spark] class RRunner[U](
 
         dataOut.writeInt(numPartitions)
 
+        dataOut.writeInt(if (isDataFrame) 1 else 0)
+
+        if (isDataFrame) {
+          SerDe.writeObject(dataOut, colNames)
+        }
+
         if (!iter.hasNext) {
           dataOut.writeInt(0)
         } else {
@@ -459,7 +459,7 @@ private[spark] object SerDe {
 
 }
 
-private[r] object SerializationFormats {
+private[spark] object SerializationFormats {
   val BYTE = "byte"
   val STRING = "string"
   val ROW = "row"
@@ -1147,6 +1147,11 @@ parquetFile <- read.parquet(sqlContext, "people.parquet")
 # Parquet files can also be registered as tables and then used in SQL statements.
 registerTempTable(parquetFile, "parquetFile")
 teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
+schema <- structType(structField("name", "string"))
+teenNames <- dapply(df, function(p) { cbind(paste("Name:", p$name)) }, schema)
+for (teenName in collect(teenNames)$name) {
+  cat(teenName, "\n")
+}
 {% endhighlight %}
 
 </div>
@@ -159,10 +159,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
         if d.outputObjectType == s.inputObjectType =>
-      // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
-      val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
-      Project(objAttr :: Nil, s.child)
+      // A workaround for SPARK-14803. Remove this after it is fixed.
+      if (d.outputObjectType.isInstanceOf[ObjectType] &&
+          d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
+        s.child
+      } else {
+        // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+        val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+        Project(objAttr :: Nil, s.child)
+      }
 
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjectType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.Encoder
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
 
 object CatalystSerde {
   def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -29,13 +30,26 @@ object CatalystSerde {
     DeserializeToObject(deserializer, generateObjAttr[T], child)
   }
 
+  def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
+    val deserializer = UnresolvedDeserializer(encoder.deserializer)
+    DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
+  }
+
   def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
     SerializeFromObject(encoderFor[T].namedExpressions, child)
   }
 
+  def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
+    SerializeFromObject(encoder.namedExpressions, child)
+  }
+
   def generateObjAttr[T : Encoder]: Attribute = {
     AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
   }
+
+  def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
+    AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
+  }
 }
 
 /**
@@ -106,6 +120,42 @@ case class MapPartitions(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
 
+object MapPartitionsInR {
+  def apply(
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Broadcast[Object]],
+      schema: StructType,
+      encoder: ExpressionEncoder[Row],
+      child: LogicalPlan): LogicalPlan = {
+    val deserialized = CatalystSerde.deserialize(child, encoder)
+    val mapped = MapPartitionsInR(
+      func,
+      packageNames,
+      broadcastVars,
+      encoder.schema,
+      schema,
+      CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+      deserialized)
+    CatalystSerde.serialize(mapped, RowEncoder(schema))
+  }
+}
+
+/**
+ * A relation produced by applying a serialized R function `func` to each partition of the `child`.
+ */
+case class MapPartitionsInR(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    outputSchema: StructType,
+    outputObjAttr: Attribute,
+    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+  override lazy val schema = outputSchema
+}
+
 object MapElements {
   def apply[T : Encoder, U : Encoder](
       func: AnyRef,
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
 import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis._
@@ -1980,6 +1981,23 @@ class Dataset[T] private[sql](
     mapPartitions(func)(encoder)
   }
 
+  /**
+   * Returns a new [[DataFrame]] that contains the result of applying a serialized R function
+   * `func` to each partition.
+   *
+   * @group func
+   */
+  private[sql] def mapPartitionsInR(
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Broadcast[Object]],
+      schema: StructType): DataFrame = {
+    val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
+    Dataset.ofRows(
+      sparkSession,
+      MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
+  }
+
   /**
    * :: Experimental ::
    * (Scala-specific)
@@ -23,12 +23,15 @@ import scala.util.matching.Regex
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.r.SerDe
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.types._
 
-private[r] object SQLUtils {
+private[sql] object SQLUtils {
   SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
 
   def createSQLContext(jsc: JavaSparkContext): SQLContext = {
@@ -111,7 +114,7 @@ private[r] object SQLUtils {
     }
   }
 
-  private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
+  private[sql] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
     val bis = new ByteArrayInputStream(bytes)
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
@@ -120,7 +123,7 @@ private[r] object SQLUtils {
     }.toSeq)
   }
 
-  private[this] def rowToRBytes(row: Row): Array[Byte] = {
+  private[sql] def rowToRBytes(row: Row): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val dos = new DataOutputStream(bos)
 
@@ -129,6 +132,29 @@ private[r] object SQLUtils {
     bos.toByteArray()
   }
 
+  // Schema for DataFrame of serialized R data
+  // TODO: introduce a user defined type for serialized R data.
+  val SERIALIZED_R_DATA_SCHEMA = StructType(Seq(StructField("R", BinaryType)))
+
+  /**
+   * The helper function for dapply() on the R side.
+   */
+  def dapply(
+      df: DataFrame,
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Object],
+      schema: StructType): DataFrame = {
+    val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
+    val realSchema =
+      if (schema == null) {
+        SERIALIZED_R_DATA_SCHEMA
+      } else {
+        schema
+      }
+    df.mapPartitionsInR(func, packageNames, bv, realSchema)
+  }
+
   def dfToCols(df: DataFrame): Array[Array[Any]] = {
     val localDF: Array[Row] = df.collect()
     val numCols = df.columns.length
@@ -307,6 +307,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
     case logical.MapPartitions(f, objAttr, child) =>
       execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
+    case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
+      execution.MapPartitionsExec(
+        execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
    case logical.MapElements(f, objAttr, child) =>
       execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
     case logical.AppendColumns(f, in, out, child) =>
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.r
+
+import org.apache.spark.api.r.RRunner
+import org.apache.spark.api.r.SerializationFormats
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.api.r.SQLUtils._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+
+/**
+ * A function wrapper that applies the given R function to each partition.
+ */
+private[sql] case class MapPartitionsRWrapper(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {
+  def apply(iter: Iterator[Any]): Iterator[Any] = {
+    // Whether the content of the current DataFrame is serialized R data.
+    val isSerializedRData =
+      if (inputSchema == SERIALIZED_R_DATA_SCHEMA) true else false
+
+    val (newIter, deserializer, colNames) =
+      if (!isSerializedRData) {
+        // Serialize each row into a byte array that can be deserialized in the R worker
+        (iter.asInstanceOf[Iterator[Row]].map { row => rowToRBytes(row) },
+         SerializationFormats.ROW, inputSchema.fieldNames)
+      } else {
+        (iter.asInstanceOf[Iterator[Row]].map { row => row(0) }, SerializationFormats.BYTE, null)
+      }
+
+    val serializer = if (outputSchema != SERIALIZED_R_DATA_SCHEMA) {
+      SerializationFormats.ROW
+    } else {
+      SerializationFormats.BYTE
+    }
+
+    val runner = new RRunner[Array[Byte]](
+      func, deserializer, serializer, packageNames, broadcastVars,
+      isDataFrame = true, colNames = colNames)
+    // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
+    val outputIter = runner.compute(newIter, -1)
+
+    if (serializer == SerializationFormats.ROW) {
+      outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
+    } else {
+      outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
+    }
+  }
+}