[SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde.
This PR: 1. supports transferring arbitrary nested array from JVM to R side in SerDe; 2. based on 1, collect() implementation is improved. Now it can support collecting data of complex types from a DataFrame. Author: Sun Rui <rui.sun@intel.com> Closes #8276 from sun-rui/SPARK-10048.
This commit is contained in:
parent
16a2be1a84
commit
71a138cd0e
|
@ -652,18 +652,49 @@ setMethod("dim",
|
|||
setMethod("collect",
|
||||
signature(x = "DataFrame"),
|
||||
function(x, stringsAsFactors = FALSE) {
|
||||
# listCols is a list of raw vectors, one per column
|
||||
listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
|
||||
cols <- lapply(listCols, function(col) {
|
||||
objRaw <- rawConnection(col)
|
||||
numRows <- readInt(objRaw)
|
||||
col <- readCol(objRaw, numRows)
|
||||
close(objRaw)
|
||||
col
|
||||
})
|
||||
names(cols) <- columns(x)
|
||||
do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors))
|
||||
})
|
||||
names <- columns(x)
|
||||
ncol <- length(names)
|
||||
if (ncol <= 0) {
|
||||
# empty data.frame with 0 columns and 0 rows
|
||||
data.frame()
|
||||
} else {
|
||||
# listCols is a list of columns
|
||||
listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
|
||||
stopifnot(length(listCols) == ncol)
|
||||
|
||||
# An empty data.frame with 0 columns and number of rows as collected
|
||||
nrow <- length(listCols[[1]])
|
||||
if (nrow <= 0) {
|
||||
df <- data.frame()
|
||||
} else {
|
||||
df <- data.frame(row.names = 1 : nrow)
|
||||
}
|
||||
|
||||
# Append columns one by one
|
||||
for (colIndex in 1 : ncol) {
|
||||
# Note: appending a column of list type into a data.frame so that
|
||||
# data of complex type can be held. But getting a cell from a column
|
||||
# of list type returns a list instead of a vector. So for columns of
|
||||
# non-complex type, append them as vector.
|
||||
col <- listCols[[colIndex]]
|
||||
if (length(col) <= 0) {
|
||||
df[[names[colIndex]]] <- col
|
||||
} else {
|
||||
# TODO: more robust check on column of primitive types
|
||||
vec <- do.call(c, col)
|
||||
if (class(vec) != "list") {
|
||||
df[[names[colIndex]]] <- vec
|
||||
} else {
|
||||
# For columns of complex type, be careful to access them.
|
||||
# Getting a column of complex type returns a list.
|
||||
# Get a cell from a column of complex type returns a list instead of a vector.
|
||||
df[[names[colIndex]]] <- col
|
||||
}
|
||||
}
|
||||
}
|
||||
df
|
||||
}
|
||||
})
|
||||
|
||||
#' Limit
|
||||
#'
|
||||
|
|
|
@ -48,6 +48,7 @@ readTypedObject <- function(con, type) {
|
|||
"r" = readRaw(con),
|
||||
"D" = readDate(con),
|
||||
"t" = readTime(con),
|
||||
"a" = readArray(con),
|
||||
"l" = readList(con),
|
||||
"n" = NULL,
|
||||
"j" = getJobj(readString(con)),
|
||||
|
@ -85,8 +86,7 @@ readTime <- function(con) {
|
|||
as.POSIXct(t, origin = "1970-01-01")
|
||||
}
|
||||
|
||||
# We only support lists where all elements are of same type
|
||||
readList <- function(con) {
|
||||
readArray <- function(con) {
|
||||
type <- readType(con)
|
||||
len <- readInt(con)
|
||||
if (len > 0) {
|
||||
|
@ -100,6 +100,25 @@ readList <- function(con) {
|
|||
}
|
||||
}
|
||||
|
||||
# Read a list. Types of each element may be different.
|
||||
# Null objects are read as NA.
|
||||
readList <- function(con) {
|
||||
len <- readInt(con)
|
||||
if (len > 0) {
|
||||
l <- vector("list", len)
|
||||
for (i in 1:len) {
|
||||
elem <- readObject(con)
|
||||
if (is.null(elem)) {
|
||||
elem <- NA
|
||||
}
|
||||
l[[i]] <- elem
|
||||
}
|
||||
l
|
||||
} else {
|
||||
list()
|
||||
}
|
||||
}
|
||||
|
||||
readRaw <- function(con) {
|
||||
dataLen <- readInt(con)
|
||||
readBin(con, raw(), as.integer(dataLen), endian = "big")
|
||||
|
@ -132,18 +151,19 @@ readDeserialize <- function(con) {
|
|||
}
|
||||
}
|
||||
|
||||
readDeserializeRows <- function(inputCon) {
|
||||
# readDeserializeRows will deserialize a DataOutputStream composed of
|
||||
# a list of lists. Since the DOS is one continuous stream and
|
||||
# the number of rows varies, we put the readRow function in a while loop
|
||||
# that terminates when the next row is empty.
|
||||
readMultipleObjects <- function(inputCon) {
|
||||
# readMultipleObjects will read multiple continuous objects from
|
||||
# a DataOutputStream. There is no preceding field telling the count
|
||||
# of the objects, so the number of objects varies, we try to read
|
||||
# all objects in a loop until the end of the stream.
|
||||
data <- list()
|
||||
while(TRUE) {
|
||||
row <- readRow(inputCon)
|
||||
if (length(row) == 0) {
|
||||
# If reaching the end of the stream, type returned should be "".
|
||||
type <- readType(inputCon)
|
||||
if (type == "") {
|
||||
break
|
||||
}
|
||||
data[[length(data) + 1L]] <- row
|
||||
data[[length(data) + 1L]] <- readTypedObject(inputCon, type)
|
||||
}
|
||||
data # this is a list of named lists now
|
||||
}
|
||||
|
@ -155,35 +175,5 @@ readRowList <- function(obj) {
|
|||
# deserialize the row.
|
||||
rawObj <- rawConnection(obj, "r+")
|
||||
on.exit(close(rawObj))
|
||||
readRow(rawObj)
|
||||
}
|
||||
|
||||
readRow <- function(inputCon) {
|
||||
numCols <- readInt(inputCon)
|
||||
if (length(numCols) > 0 && numCols > 0) {
|
||||
lapply(1:numCols, function(x) {
|
||||
obj <- readObject(inputCon)
|
||||
if (is.null(obj)) {
|
||||
NA
|
||||
} else {
|
||||
obj
|
||||
}
|
||||
}) # each row is a list now
|
||||
} else {
|
||||
list()
|
||||
}
|
||||
}
|
||||
|
||||
# Take a single column as Array[Byte] and deserialize it into an atomic vector
|
||||
readCol <- function(inputCon, numRows) {
|
||||
if (numRows > 0) {
|
||||
# sapply can not work with POSIXlt
|
||||
do.call(c, lapply(1:numRows, function(x) {
|
||||
value <- readObject(inputCon)
|
||||
# Replace NULL with NA so we can coerce to vectors
|
||||
if (is.null(value)) NA else value
|
||||
}))
|
||||
} else {
|
||||
vector()
|
||||
}
|
||||
readObject(rawObj)
|
||||
}
|
||||
|
|
|
@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) {
|
|||
serializeRow <- function(row) {
|
||||
rawObj <- rawConnection(raw(0), "wb")
|
||||
on.exit(close(rawObj))
|
||||
writeRow(rawObj, row)
|
||||
writeGenericList(rawObj, row)
|
||||
rawConnectionValue(rawObj)
|
||||
}
|
||||
|
||||
writeRow <- function(con, row) {
|
||||
numCols <- length(row)
|
||||
writeInt(con, numCols)
|
||||
for (i in 1:numCols) {
|
||||
writeObject(con, row[[i]])
|
||||
}
|
||||
}
|
||||
|
||||
writeRaw <- function(con, batch) {
|
||||
writeInt(con, length(batch))
|
||||
writeBin(batch, con, endian = "big")
|
||||
|
|
77
R/pkg/inst/tests/test_Serde.R
Normal file
77
R/pkg/inst/tests/test_Serde.R
Normal file
|
@ -0,0 +1,77 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
context("SerDe functionality")
|
||||
|
||||
sc <- sparkR.init()
|
||||
|
||||
test_that("SerDe of primitive types", {
|
||||
x <- callJStatic("SparkRHandler", "echo", 1L)
|
||||
expect_equal(x, 1L)
|
||||
expect_equal(class(x), "integer")
|
||||
|
||||
x <- callJStatic("SparkRHandler", "echo", 1)
|
||||
expect_equal(x, 1)
|
||||
expect_equal(class(x), "numeric")
|
||||
|
||||
x <- callJStatic("SparkRHandler", "echo", TRUE)
|
||||
expect_true(x)
|
||||
expect_equal(class(x), "logical")
|
||||
|
||||
x <- callJStatic("SparkRHandler", "echo", "abc")
|
||||
expect_equal(x, "abc")
|
||||
expect_equal(class(x), "character")
|
||||
})
|
||||
|
||||
test_that("SerDe of list of primitive types", {
|
||||
x <- list(1L, 2L, 3L)
|
||||
y <- callJStatic("SparkRHandler", "echo", x)
|
||||
expect_equal(x, y)
|
||||
expect_equal(class(y[[1]]), "integer")
|
||||
|
||||
x <- list(1, 2, 3)
|
||||
y <- callJStatic("SparkRHandler", "echo", x)
|
||||
expect_equal(x, y)
|
||||
expect_equal(class(y[[1]]), "numeric")
|
||||
|
||||
x <- list(TRUE, FALSE)
|
||||
y <- callJStatic("SparkRHandler", "echo", x)
|
||||
expect_equal(x, y)
|
||||
expect_equal(class(y[[1]]), "logical")
|
||||
|
||||
x <- list("a", "b", "c")
|
||||
y <- callJStatic("SparkRHandler", "echo", x)
|
||||
expect_equal(x, y)
|
||||
expect_equal(class(y[[1]]), "character")
|
||||
|
||||
# Empty list
|
||||
x <- list()
|
||||
y <- callJStatic("SparkRHandler", "echo", x)
|
||||
expect_equal(x, y)
|
||||
})
|
||||
|
||||
test_that("SerDe of list of lists", {
|
||||
x <- list(list(1L, 2L, 3L), list(1, 2, 3),
|
||||
list(TRUE, FALSE), list("a", "b", "c"))
|
||||
y <- callJStatic("SparkRHandler", "echo", x)
|
||||
expect_equal(x, y)
|
||||
|
||||
# List of empty lists
|
||||
x <- list(list(), list())
|
||||
y <- callJStatic("SparkRHandler", "echo", x)
|
||||
expect_equal(x, y)
|
||||
})
|
|
@ -94,7 +94,7 @@ if (isEmpty != 0) {
|
|||
} else if (deserializer == "string") {
|
||||
data <- as.list(readLines(inputCon))
|
||||
} else if (deserializer == "row") {
|
||||
data <- SparkR:::readDeserializeRows(inputCon)
|
||||
data <- SparkR:::readMultipleObjects(inputCon)
|
||||
}
|
||||
# Timing reading input data for execution
|
||||
inputElap <- elapsedSecs()
|
||||
|
@ -120,7 +120,7 @@ if (isEmpty != 0) {
|
|||
} else if (deserializer == "string") {
|
||||
data <- readLines(inputCon)
|
||||
} else if (deserializer == "row") {
|
||||
data <- SparkR:::readDeserializeRows(inputCon)
|
||||
data <- SparkR:::readMultipleObjects(inputCon)
|
||||
}
|
||||
# Timing reading input data for execution
|
||||
inputElap <- elapsedSecs()
|
||||
|
|
|
@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend)
|
|||
|
||||
if (objId == "SparkRHandler") {
|
||||
methodName match {
|
||||
// This function is for test-purpose only
|
||||
case "echo" =>
|
||||
val args = readArgs(numArgs, dis)
|
||||
assert(numArgs == 1)
|
||||
|
||||
writeInt(dos, 0)
|
||||
writeObject(dos, args(0))
|
||||
case "stopBackend" =>
|
||||
writeInt(dos, 0)
|
||||
writeType(dos, "void")
|
||||
|
|
|
@ -149,6 +149,10 @@ private[spark] object SerDe {
|
|||
case 'b' => readBooleanArr(dis)
|
||||
case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
|
||||
case 'r' => readBytesArr(dis)
|
||||
case 'l' => {
|
||||
val len = readInt(dis)
|
||||
(0 until len).map(_ => readList(dis)).toArray
|
||||
}
|
||||
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
|
||||
}
|
||||
}
|
||||
|
@ -200,6 +204,9 @@ private[spark] object SerDe {
|
|||
case "date" => dos.writeByte('D')
|
||||
case "time" => dos.writeByte('t')
|
||||
case "raw" => dos.writeByte('r')
|
||||
// Array of primitive types
|
||||
case "array" => dos.writeByte('a')
|
||||
// Array of objects
|
||||
case "list" => dos.writeByte('l')
|
||||
case "jobj" => dos.writeByte('j')
|
||||
case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
|
||||
|
@ -211,26 +218,35 @@ private[spark] object SerDe {
|
|||
writeType(dos, "void")
|
||||
} else {
|
||||
value.getClass.getName match {
|
||||
case "java.lang.Character" =>
|
||||
writeType(dos, "character")
|
||||
writeString(dos, value.asInstanceOf[Character].toString)
|
||||
case "java.lang.String" =>
|
||||
writeType(dos, "character")
|
||||
writeString(dos, value.asInstanceOf[String])
|
||||
case "long" | "java.lang.Long" =>
|
||||
case "java.lang.Long" =>
|
||||
writeType(dos, "double")
|
||||
writeDouble(dos, value.asInstanceOf[Long].toDouble)
|
||||
case "float" | "java.lang.Float" =>
|
||||
case "java.lang.Float" =>
|
||||
writeType(dos, "double")
|
||||
writeDouble(dos, value.asInstanceOf[Float].toDouble)
|
||||
case "decimal" | "java.math.BigDecimal" =>
|
||||
case "java.math.BigDecimal" =>
|
||||
writeType(dos, "double")
|
||||
val javaDecimal = value.asInstanceOf[java.math.BigDecimal]
|
||||
writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble)
|
||||
case "double" | "java.lang.Double" =>
|
||||
case "java.lang.Double" =>
|
||||
writeType(dos, "double")
|
||||
writeDouble(dos, value.asInstanceOf[Double])
|
||||
case "int" | "java.lang.Integer" =>
|
||||
case "java.lang.Byte" =>
|
||||
writeType(dos, "integer")
|
||||
writeInt(dos, value.asInstanceOf[Byte].toInt)
|
||||
case "java.lang.Short" =>
|
||||
writeType(dos, "integer")
|
||||
writeInt(dos, value.asInstanceOf[Short].toInt)
|
||||
case "java.lang.Integer" =>
|
||||
writeType(dos, "integer")
|
||||
writeInt(dos, value.asInstanceOf[Int])
|
||||
case "boolean" | "java.lang.Boolean" =>
|
||||
case "java.lang.Boolean" =>
|
||||
writeType(dos, "logical")
|
||||
writeBoolean(dos, value.asInstanceOf[Boolean])
|
||||
case "java.sql.Date" =>
|
||||
|
@ -242,43 +258,48 @@ private[spark] object SerDe {
|
|||
case "java.sql.Timestamp" =>
|
||||
writeType(dos, "time")
|
||||
writeTime(dos, value.asInstanceOf[Timestamp])
|
||||
|
||||
// Handle arrays
|
||||
|
||||
// Array of primitive types
|
||||
|
||||
// Special handling for byte array
|
||||
case "[B" =>
|
||||
writeType(dos, "raw")
|
||||
writeBytes(dos, value.asInstanceOf[Array[Byte]])
|
||||
// TODO: Types not handled right now include
|
||||
// byte, char, short, float
|
||||
|
||||
// Handle arrays
|
||||
case "[Ljava.lang.String;" =>
|
||||
writeType(dos, "list")
|
||||
writeStringArr(dos, value.asInstanceOf[Array[String]])
|
||||
case "[C" =>
|
||||
writeType(dos, "array")
|
||||
writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString))
|
||||
case "[S" =>
|
||||
writeType(dos, "array")
|
||||
writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt))
|
||||
case "[I" =>
|
||||
writeType(dos, "list")
|
||||
writeType(dos, "array")
|
||||
writeIntArr(dos, value.asInstanceOf[Array[Int]])
|
||||
case "[J" =>
|
||||
writeType(dos, "list")
|
||||
writeType(dos, "array")
|
||||
writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
|
||||
case "[F" =>
|
||||
writeType(dos, "array")
|
||||
writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble))
|
||||
case "[D" =>
|
||||
writeType(dos, "list")
|
||||
writeType(dos, "array")
|
||||
writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
|
||||
case "[Z" =>
|
||||
writeType(dos, "list")
|
||||
writeType(dos, "array")
|
||||
writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
|
||||
case "[[B" =>
|
||||
|
||||
// Array of objects, null objects use "void" type
|
||||
case c if c.startsWith("[") =>
|
||||
writeType(dos, "list")
|
||||
writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
|
||||
case otherName =>
|
||||
// Handle array of objects
|
||||
if (otherName.startsWith("[L")) {
|
||||
val objArr = value.asInstanceOf[Array[Object]]
|
||||
writeType(dos, "list")
|
||||
writeType(dos, "jobj")
|
||||
dos.writeInt(objArr.length)
|
||||
objArr.foreach(o => writeJObj(dos, o))
|
||||
} else {
|
||||
writeType(dos, "jobj")
|
||||
writeJObj(dos, value)
|
||||
}
|
||||
val array = value.asInstanceOf[Array[Object]]
|
||||
writeInt(dos, array.length)
|
||||
array.foreach(elem => writeObject(dos, elem))
|
||||
|
||||
case _ =>
|
||||
writeType(dos, "jobj")
|
||||
writeJObj(dos, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -350,11 +371,6 @@ private[spark] object SerDe {
|
|||
value.foreach(v => writeString(out, v))
|
||||
}
|
||||
|
||||
def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
|
||||
writeType(out, "raw")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => writeBytes(out, v))
|
||||
}
|
||||
}
|
||||
|
||||
private[r] object SerializationFormats {
|
||||
|
|
|
@ -98,27 +98,17 @@ private[r] object SQLUtils {
|
|||
val bos = new ByteArrayOutputStream()
|
||||
val dos = new DataOutputStream(bos)
|
||||
|
||||
SerDe.writeInt(dos, row.length)
|
||||
(0 until row.length).map { idx =>
|
||||
val obj: Object = row(idx).asInstanceOf[Object]
|
||||
SerDe.writeObject(dos, obj)
|
||||
}
|
||||
val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
|
||||
SerDe.writeObject(dos, cols)
|
||||
bos.toByteArray()
|
||||
}
|
||||
|
||||
def dfToCols(df: DataFrame): Array[Array[Byte]] = {
|
||||
def dfToCols(df: DataFrame): Array[Array[Any]] = {
|
||||
// localDF is Array[Row]
|
||||
val localDF = df.collect()
|
||||
val numCols = df.columns.length
|
||||
// dfCols is Array[Array[Any]]
|
||||
val dfCols = convertRowsToColumns(localDF, numCols)
|
||||
|
||||
dfCols.map { col =>
|
||||
colToRBytes(col)
|
||||
}
|
||||
}
|
||||
|
||||
def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
|
||||
// result is Array[Array[Any]]
|
||||
(0 until numCols).map { colIdx =>
|
||||
localDF.map { row =>
|
||||
row(colIdx)
|
||||
|
@ -126,20 +116,6 @@ private[r] object SQLUtils {
|
|||
}.toArray
|
||||
}
|
||||
|
||||
def colToRBytes(col: Array[Any]): Array[Byte] = {
|
||||
val numRows = col.length
|
||||
val bos = new ByteArrayOutputStream()
|
||||
val dos = new DataOutputStream(bos)
|
||||
|
||||
SerDe.writeInt(dos, numRows)
|
||||
|
||||
col.map { item =>
|
||||
val obj: Object = item.asInstanceOf[Object]
|
||||
SerDe.writeObject(dos, obj)
|
||||
}
|
||||
bos.toByteArray()
|
||||
}
|
||||
|
||||
def saveMode(mode: String): SaveMode = {
|
||||
mode match {
|
||||
case "append" => SaveMode.Append
|
||||
|
|
Loading…
Reference in a new issue