[SPARK-12919][SPARKR] Implement dapply() on DataFrame in SparkR.

## What changes were proposed in this pull request?

dapply() applies an R function to each partition of a DataFrame and returns a new DataFrame.

The function signature is:

	dapply(df, function(localDF) {}, schema = NULL)

R function input: a local data.frame holding the rows of one partition, on the worker node.
R function output: a local data.frame.

The schema specifies the row format of the resulting DataFrame and must match the output of the R function.
If no schema is specified, each partition of the result is serialized in R into a single byte array; such a DataFrame can still be processed by subsequent calls to dapply().
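
A rough usage sketch (this assumes an initialized SparkR sqlContext; the data, column names, and function body are illustrative only and mirror the unit test added in this patch):

    df <- createDataFrame(sqlContext,
                          list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
                          c("a", "b", "c"))
    # The declared output schema must match what the function returns.
    schema <- structType(structField("a", "integer"), structField("b", "double"),
                         structField("c", "string"), structField("d", "integer"))
    # func receives one local data.frame per partition and returns a data.frame.
    df1 <- dapply(df, function(ldf) { cbind(ldf, d = ldf$a + 1L) }, schema)
    collect(df1)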

## How was this patch tested?
SparkR unit tests.

Author: Sun Rui <rui.sun@intel.com>
Author: Sun Rui <sunrui2016@gmail.com>

Closes #12493 from sun-rui/SPARK-12919.
Sun Rui authored on 2016-04-29 16:41:07 -07:00; committed by Shivaram Venkataraman
parent d78fbcc3cc
commit 4ae9fe091c
15 changed files with 337 additions and 15 deletions

@@ -45,6 +45,7 @@ exportMethods("arrange",
"covar_samp",
"covar_pop",
"crosstab",
"dapply",
"describe",
"dim",
"distinct",

@@ -21,6 +21,7 @@
NULL
setOldClass("jobj")
setOldClass("structType")
#' @title S4 class that represents a SparkDataFrame
#' @description DataFrames can be created using functions like \link{createDataFrame},
@@ -1125,6 +1126,66 @@ setMethod("summarize",
agg(x, ...)
})
#' dapply
#'
#' Apply a function to each partition of a DataFrame.
#'
#' @param x A SparkDataFrame
#' @param func A function to be applied to each partition of the SparkDataFrame.
#' func should have only one parameter, to which a data.frame corresponding
#' to each partition will be passed.
#' The output of func should be a data.frame.
#' @param schema The schema of the resulting DataFrame after the function is applied.
#' It must match the output of func.
#' @family SparkDataFrame functions
#' @rdname dapply
#' @name dapply
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame (sqlContext, iris)
#' df1 <- dapply(df, function(x) { x }, schema(df))
#' collect(df1)
#'
#' # filter and add a column
#' df <- createDataFrame (
#' sqlContext,
#' list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
#' c("a", "b", "c"))
#' schema <- structType(structField("a", "integer"), structField("b", "double"),
#' structField("c", "string"), structField("d", "integer"))
#' df1 <- dapply(
#' df,
#' function(x) {
#' y <- x[x[1] > 1, ]
#' y <- cbind(y, y[1] + 1L)
#' },
#' schema)
#' collect(df1)
#' # the result
#' # a b c d
#' # 1 2 2 2 3
#' # 2 3 3 3 4
#' }
setMethod("dapply",
signature(x = "SparkDataFrame", func = "function", schema = "structType"),
function(x, func, schema) {
packageNamesArr <- serialize(.sparkREnv[[".packages"]],
connection = NULL)
broadcastArr <- lapply(ls(.broadcastNames),
function(name) { get(name, .broadcastNames) })
sdf <- callJStatic(
"org.apache.spark.sql.api.r.SQLUtils",
"dapply",
x@sdf,
serialize(cleanClosure(func), connection = NULL),
packageNamesArr,
broadcastArr,
schema$jobj)
dataFrame(sdf)
})
############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #

@@ -446,6 +446,10 @@ setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") })
#' @export
setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") })
#' @rdname dapply
#' @export
setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })

@@ -2043,6 +2043,46 @@ test_that("Histogram", {
df <- as.DataFrame(sqlContext, data.frame(x = c(1, 2, 3, 4, 100)))
expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
})
test_that("dapply() on a DataFrame", {
df <- createDataFrame (
sqlContext,
list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
c("a", "b", "c"))
ldf <- collect(df)
df1 <- dapply(df, function(x) { x }, schema(df))
result <- collect(df1)
expect_identical(ldf, result)
# Filter and add a column
schema <- structType(structField("a", "integer"), structField("b", "double"),
structField("c", "string"), structField("d", "integer"))
df1 <- dapply(
df,
function(x) {
y <- x[x$a > 1, ]
y <- cbind(y, y$a + 1L)
},
schema)
result <- collect(df1)
expected <- ldf[ldf$a > 1, ]
expected$d <- expected$a + 1L
rownames(expected) <- NULL
expect_identical(expected, result)
# Remove the added column
df2 <- dapply(
df1,
function(x) {
x[, c("a", "b", "c")]
},
schema(df))
result <- collect(df2)
expected <- expected[, c("a", "b", "c")]
expect_identical(expected, result)
})
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)

@@ -84,6 +84,13 @@ broadcastElap <- elapsedSecs()
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)
isDataFrame <- as.logical(SparkR:::readInt(inputCon))
# If isDataFrame, then read column names
if (isDataFrame) {
colNames <- SparkR:::readObject(inputCon)
}
isEmpty <- SparkR:::readInt(inputCon)
if (isEmpty != 0) {
@@ -100,7 +107,34 @@ if (isEmpty != 0) {
# Timing reading input data for execution
inputElap <- elapsedSecs()
output <- computeFunc(partition, data)
if (isDataFrame) {
if (deserializer == "row") {
# Transform the list of rows into a data.frame
# Note that the optional argument stringsAsFactors for rbind is
# available since R 3.2.4. So we set the global option here.
oldOpt <- getOption("stringsAsFactors")
options(stringsAsFactors = FALSE)
data <- do.call(rbind.data.frame, data)
options(stringsAsFactors = oldOpt)
names(data) <- colNames
} else {
# Check to see if data is a valid data.frame
stopifnot(deserializer == "byte")
stopifnot(class(data) == "data.frame")
}
output <- computeFunc(data)
if (serializer == "row") {
# Transform the result data.frame back to a list of rows
output <- split(output, seq(nrow(output)))
} else {
# Serialize the output to a byte array
stopifnot(serializer == "byte")
}
} else {
output <- computeFunc(partition, data)
}
# Timing computing
computeElap <- elapsedSecs()
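
A standalone sketch of the row-to-data.frame round trip performed above, in plain R (no Spark required; the rows and column names are made up for illustration):

    rows <- list(list(1L, "a"), list(2L, "b"))   # rows as deserialized from the JVM
    oldOpt <- getOption("stringsAsFactors")
    options(stringsAsFactors = FALSE)
    ldf <- do.call(rbind.data.frame, rows)       # bind the row lists into a local data.frame
    options(stringsAsFactors = oldOpt)
    names(ldf) <- c("x", "y")                    # column names sent by the JVM side
    out <- split(ldf, seq(nrow(ldf)))            # back to a list of one-row data.frames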

@@ -46,7 +46,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// The parent may be also an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context)
runner.compute(parentIterator, partition.index, context)
runner.compute(parentIterator, partition.index)
}
}

@@ -38,7 +38,9 @@ private[spark] class RRunner[U](
serializer: String,
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
numPartitions: Int = -1)
numPartitions: Int = -1,
isDataFrame: Boolean = false,
colNames: Array[String] = null)
extends Logging {
private var bootTime: Double = _
private var dataStream: DataInputStream = _
@@ -53,8 +55,7 @@ private[spark] class RRunner[U](
def compute(
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext): Iterator[U] = {
partitionIndex: Int): Iterator[U] = {
// Timing start
bootTime = System.currentTimeMillis / 1000.0
@@ -148,6 +149,12 @@ private[spark] class RRunner[U](
dataOut.writeInt(numPartitions)
dataOut.writeInt(if (isDataFrame) 1 else 0)
if (isDataFrame) {
SerDe.writeObject(dataOut, colNames)
}
if (!iter.hasNext) {
dataOut.writeInt(0)
} else {

@@ -459,7 +459,7 @@ private[spark] object SerDe {
}
private[r] object SerializationFormats {
private[spark] object SerializationFormats {
val BYTE = "byte"
val STRING = "string"
val ROW = "row"

@@ -1147,6 +1147,11 @@ parquetFile <- read.parquet(sqlContext, "people.parquet")
# Parquet files can also be registered as tables and then used in SQL statements.
registerTempTable(parquetFile, "parquetFile")
teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
schema <- structType(structField("name", "string"))
teenNames <- dapply(df, function(p) { cbind(paste("Name:", p$name)) }, schema)
for (teenName in collect(teenNames)$name) {
cat(teenName, "\n")
}
{% endhighlight %}
</div>

@@ -159,10 +159,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Project(objAttr :: Nil, s.child)
// A workaround for SPARK-14803. Remove this after it is fixed.
if (d.outputObjectType.isInstanceOf[ObjectType] &&
d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
s.child
} else {
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Project(objAttr :: Nil, s.child)
}
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)

@@ -17,11 +17,12 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Encoder
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types._
object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -29,13 +30,26 @@ object CatalystSerde {
DeserializeToObject(deserializer, generateObjAttr[T], child)
}
def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
val deserializer = UnresolvedDeserializer(encoder.deserializer)
DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
}
def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}
def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
SerializeFromObject(encoder.namedExpressions, child)
}
def generateObjAttr[T : Encoder]: Attribute = {
AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
}
def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
}
}
/**
@@ -106,6 +120,42 @@ case class MapPartitions(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
object MapPartitionsInR {
def apply(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
schema: StructType,
encoder: ExpressionEncoder[Row],
child: LogicalPlan): LogicalPlan = {
val deserialized = CatalystSerde.deserialize(child, encoder)
val mapped = MapPartitionsInR(
func,
packageNames,
broadcastVars,
encoder.schema,
schema,
CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
deserialized)
CatalystSerde.serialize(mapped, RowEncoder(schema))
}
}
/**
* A relation produced by applying a serialized R function `func` to each partition of the `child`.
*
*/
case class MapPartitionsInR(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
inputSchema: StructType,
outputSchema: StructType,
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
override lazy val schema = outputSchema
}
object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,

@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
@@ -1980,6 +1981,23 @@ class Dataset[T] private[sql](
mapPartitions(func)(encoder)
}
/**
* Returns a new [[DataFrame]] that contains the result of applying a serialized R function
* `func` to each partition.
*
* @group func
*/
private[sql] def mapPartitionsInR(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
schema: StructType): DataFrame = {
val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
Dataset.ofRows(
sparkSession,
MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
}
/**
* :: Experimental ::
* (Scala-specific)

@@ -23,12 +23,15 @@ import scala.util.matching.Regex
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types._
private[r] object SQLUtils {
private[sql] object SQLUtils {
SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
@@ -111,7 +114,7 @@ private[r] object SQLUtils {
}
}
private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
private[sql] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
val bis = new ByteArrayInputStream(bytes)
val dis = new DataInputStream(bis)
val num = SerDe.readInt(dis)
@@ -120,7 +123,7 @@
}.toSeq)
}
private[this] def rowToRBytes(row: Row): Array[Byte] = {
private[sql] def rowToRBytes(row: Row): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val dos = new DataOutputStream(bos)
@@ -129,6 +132,29 @@
bos.toByteArray()
}
// Schema for DataFrame of serialized R data
// TODO: introduce a user defined type for serialized R data.
val SERIALIZED_R_DATA_SCHEMA = StructType(Seq(StructField("R", BinaryType)))
/**
* The helper function for dapply() on the R side.
*/
def dapply(
df: DataFrame,
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Object],
schema: StructType): DataFrame = {
val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])
val realSchema =
if (schema == null) {
SERIALIZED_R_DATA_SCHEMA
} else {
schema
}
df.mapPartitionsInR(func, packageNames, bv, realSchema)
}
def dfToCols(df: DataFrame): Array[Array[Any]] = {
val localDF: Array[Row] = df.collect()
val numCols = df.columns.length

@@ -307,6 +307,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, objAttr, child) =>
execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
execution.MapPartitionsExec(
execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
case logical.MapElements(f, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>

@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.r
import org.apache.spark.api.r.RRunner
import org.apache.spark.api.r.SerializationFormats
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
/**
* A function wrapper that applies the given R function to each partition.
*/
private[sql] case class MapPartitionsRWrapper(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
inputSchema: StructType,
outputSchema: StructType) extends (Iterator[Any] => Iterator[Any]) {
def apply(iter: Iterator[Any]): Iterator[Any] = {
// Whether the content of the current DataFrame is serialized R data
val isSerializedRData = inputSchema == SERIALIZED_R_DATA_SCHEMA
val (newIter, deserializer, colNames) =
if (!isSerializedRData) {
// Serialize each row into a byte array that can be deserialized in the R worker
(iter.asInstanceOf[Iterator[Row]].map {row => rowToRBytes(row)},
SerializationFormats.ROW, inputSchema.fieldNames)
} else {
(iter.asInstanceOf[Iterator[Row]].map { row => row(0) }, SerializationFormats.BYTE, null)
}
val serializer = if (outputSchema != SERIALIZED_R_DATA_SCHEMA) {
SerializationFormats.ROW
} else {
SerializationFormats.BYTE
}
val runner = new RRunner[Array[Byte]](
func, deserializer, serializer, packageNames, broadcastVars,
isDataFrame = true, colNames = colNames)
// Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.
val outputIter = runner.compute(newIter, -1)
if (serializer == SerializationFormats.ROW) {
outputIter.map { bytes => bytesToRow(bytes, outputSchema) }
} else {
outputIter.map { bytes => Row.fromSeq(Seq(bytes)) }
}
}
}