diff --git a/.gitignore b/.gitignore index d162fa9cca..d54d21b802 100644 --- a/.gitignore +++ b/.gitignore @@ -63,6 +63,8 @@ ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml +R-unit-tests.log +R/unit-tests.out # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index 8c61e67a0c..8aca5a7f7a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -67,3 +67,5 @@ logs .*scalastyle-output.xml .*dependency-reduced-pom.xml known_translations +DESCRIPTION +NAMESPACE diff --git a/R/.gitignore b/R/.gitignore new file mode 100644 index 0000000000..9a5889ba28 --- /dev/null +++ b/R/.gitignore @@ -0,0 +1,6 @@ +*.o +*.so +*.Rd +lib +pkg/man +pkg/html diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md new file mode 100644 index 0000000000..931d01549b --- /dev/null +++ b/R/DOCUMENTATION.md @@ -0,0 +1,12 @@ +# SparkR Documentation + +SparkR documentation is generated using in-source comments annotated using using +`roxygen2`. After making changes to the documentation, to generate man pages, +you can run the following from an R console in the SparkR home directory + + library(devtools) + devtools::document(pkg="./pkg", roclets=c("rd")) + +You can verify if your changes are good by running + + R CMD check pkg/ diff --git a/R/README.md b/R/README.md new file mode 100644 index 0000000000..a6970e39b5 --- /dev/null +++ b/R/README.md @@ -0,0 +1,67 @@ +# R on Spark + +SparkR is an R package that provides a light-weight frontend to use Spark from R. + +### SparkR development + +#### Build Spark + +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +``` + build/mvn -DskipTests -Psparkr package +``` + +#### Running sparkR + +You can start using SparkR by launching the SparkR shell with + + ./bin/sparkR + +The `sparkR` script automatically creates a SparkContext with Spark by default in +local mode. To specify the Spark master of a cluster for the automatically created +SparkContext, you can run + + ./bin/sparkR --master "local[2]" + +To set other options like driver memory, executor memory etc. you can pass in the [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR` + +#### Using SparkR from RStudio + +If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example +``` +# Set this to where Spark is installed +Sys.setenv(SPARK_HOME="/Users/shivaram/spark") +# This line loads SparkR from the installed directory +.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) +library(SparkR) +sc <- sparkR.init(master="local") +``` + +#### Making changes to SparkR + +The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. +Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below. + +#### Generating documentation + +The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. + +### Examples, Unit tests + +SparkR comes with several sample programs in the `examples/src/main/r` directory. +To run one of them, use `./bin/sparkR `. For example: + + ./bin/sparkR examples/src/main/r/pi.R local[2] + +You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): + + R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' + ./R/run-tests.sh + +### Running on YARN +The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +``` +export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md new file mode 100644 index 0000000000..3f889c0ca3 --- /dev/null +++ b/R/WINDOWS.md @@ -0,0 +1,13 @@ +## Building SparkR on Windows + +To build SparkR on Windows, the following steps are required + +1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to +include Rtools and R in `PATH`. +2. Install +[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +`JAVA_HOME` in the system environment variables. +3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` +directory in Maven in `PATH`. +4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). +5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` diff --git a/R/create-docs.sh b/R/create-docs.sh new file mode 100755 index 0000000000..4194172a2e --- /dev/null +++ b/R/create-docs.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Script to create API docs for SparkR +# This requires `devtools` and `knitr` to be installed on the machine. + +# After running this script the html docs can be found in +# $SPARK_HOME/R/pkg/html + +# Figure out where the script is +export FWDIR="$(cd "`dirname "$0"`"; pwd)" +pushd $FWDIR + +# Generate Rd file +Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' + +# Install the package +./install-dev.sh + +# Now create HTML files + +# knit_rd puts html in current working directory +mkdir -p pkg/html +pushd pkg/html + +Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' + +popd + +popd diff --git a/R/install-dev.bat b/R/install-dev.bat new file mode 100644 index 0000000000..008a5c668b --- /dev/null +++ b/R/install-dev.bat @@ -0,0 +1,27 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Install development version of SparkR +rem + +set SPARK_HOME=%~dp0.. + +MKDIR %SPARK_HOME%\R\lib + +R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ diff --git a/R/install-dev.sh b/R/install-dev.sh new file mode 100755 index 0000000000..55ed6f4be1 --- /dev/null +++ b/R/install-dev.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + + +FWDIR="$(cd `dirname $0`; pwd)" +LIB_DIR="$FWDIR/lib" + +mkdir -p $LIB_DIR + +# Install R +R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ diff --git a/R/log4j.properties b/R/log4j.properties new file mode 100644 index 0000000000..701adb2a3d --- /dev/null +++ b/R/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION new file mode 100644 index 0000000000..1842b97d43 --- /dev/null +++ b/R/pkg/DESCRIPTION @@ -0,0 +1,35 @@ +Package: SparkR +Type: Package +Title: R frontend for Spark +Version: 1.4.0 +Date: 2013-09-09 +Author: The Apache Software Foundation +Maintainer: Shivaram Venkataraman +Imports: + methods +Depends: + R (>= 3.0), + methods, +Suggests: + testthat +Description: R frontend for Spark +License: Apache License (== 2.0) +Collate: + 'generics.R' + 'jobj.R' + 'SQLTypes.R' + 'RDD.R' + 'pairRDD.R' + 'column.R' + 'group.R' + 'DataFrame.R' + 'SQLContext.R' + 'broadcast.R' + 'context.R' + 'deserialize.R' + 'serialize.R' + 'sparkR.R' + 'backend.R' + 'client.R' + 'utils.R' + 'zzz.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE new file mode 100644 index 0000000000..a354cdce74 --- /dev/null +++ b/R/pkg/NAMESPACE @@ -0,0 +1,182 @@ +#exportPattern("^[[:alpha:]]+") +exportClasses("RDD") +exportClasses("Broadcast") +exportMethods( + "aggregateByKey", + "aggregateRDD", + "cache", + "checkpoint", + "coalesce", + "cogroup", + "collect", + "collectAsMap", + "collectPartition", + "combineByKey", + "count", + "countByKey", + "countByValue", + "distinct", + "Filter", + "filterRDD", + "first", + "flatMap", + "flatMapValues", + "fold", + "foldByKey", + "foreach", + "foreachPartition", + "fullOuterJoin", + "glom", + "groupByKey", + "join", + "keyBy", + "keys", + "length", + "lapply", + "lapplyPartition", + "lapplyPartitionsWithIndex", + "leftOuterJoin", + "lookup", + "map", + "mapPartitions", + "mapPartitionsWithIndex", + "mapValues", + "maximum", + "minimum", + "numPartitions", + "partitionBy", + "persist", + "pipeRDD", + "reduce", + "reduceByKey", + "reduceByKeyLocally", + "repartition", + "rightOuterJoin", + "sampleRDD", + "saveAsTextFile", + "saveAsObjectFile", + "sortBy", + "sortByKey", + "sumRDD", + "take", + "takeOrdered", + "takeSample", + "top", + "unionRDD", + "unpersist", + "value", + "values", + "zipRDD", + "zipWithIndex", + "zipWithUniqueId" + ) + +# S3 methods exported +export( + "textFile", + "objectFile", + "parallelize", + "hashCode", + "includePackage", + "broadcast", + "setBroadcastValue", + "setCheckpointDir" + ) +export("sparkR.init") +export("sparkR.stop") +export("print.jobj") +useDynLib(SparkR, stringHashCode) +importFrom(methods, setGeneric, setMethod, setOldClass) + +# SparkRSQL + +exportClasses("DataFrame") + +exportMethods("columns", + "distinct", + "dtypes", + "explain", + "filter", + "groupBy", + "head", + "insertInto", + "intersect", + "isLocal", + "limit", + "orderBy", + "names", + "printSchema", + "registerTempTable", + "repartition", + "sampleDF", + "saveAsParquetFile", + "saveAsTable", + "saveDF", + "schema", + "select", + "selectExpr", + "show", + "showDF", + "sortDF", + "subtract", + "toJSON", + "toRDD", + "unionAll", + "where", + "withColumn", + "withColumnRenamed") + +exportClasses("Column") + +exportMethods("abs", + "alias", + "approxCountDistinct", + "asc", + "avg", + "cast", + "contains", + "countDistinct", + "desc", + "endsWith", + "getField", + "getItem", + "isNotNull", + "isNull", + "last", + "like", + "lower", + "max", + "mean", + "min", + "rlike", + "sqrt", + "startsWith", + "substr", + "sum", + "sumDistinct", + "upper") + +exportClasses("GroupedData") +exportMethods("agg") + +export("sparkRSQL.init", + "sparkRHive.init") + +export("cacheTable", + "clearCache", + "createDataFrame", + "createExternalTable", + "dropTempTable", + "jsonFile", + "jsonRDD", + "loadDF", + "parquetFile", + "sql", + "table", + "tableNames", + "tables", + "toDF", + "uncacheTable") + +export("print.structType", + "print.structField") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R new file mode 100644 index 0000000000..feafd56909 --- /dev/null +++ b/R/pkg/R/DataFrame.R @@ -0,0 +1,1270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DataFrame.R - DataFrame class and methods implemented in S4 OO classes + +#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame +#' @description DataFrames can be created using functions like +#' \code{jsonFile}, \code{table} etc. +#' @rdname DataFrame +#' @seealso jsonFile, table +#' +#' @param env An R environment that stores bookkeeping states of the DataFrame +#' @param sdf A Java object reference to the backing Scala DataFrame +#' @export +setClass("DataFrame", + slots = list(env = "environment", + sdf = "jobj")) + +setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { + .Object@env <- new.env() + .Object@env$isCached <- isCached + + .Object@sdf <- sdf + .Object +}) + +#' @rdname DataFrame +#' @export +dataFrame <- function(sdf, isCached = FALSE) { + new("DataFrame", sdf, isCached) +} + +############################ DataFrame Methods ############################################## + +#' Print Schema of a DataFrame +#' +#' Prints out the schema in tree format +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname printSchema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' printSchema(df) +#'} +setMethod("printSchema", + signature(x = "DataFrame"), + function(x) { + schemaString <- callJMethod(schema(x)$jobj, "treeString") + cat(schemaString) + }) + +#' Get schema object +#' +#' Returns the schema of this DataFrame as a structType object. +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname schema +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dfSchema <- schema(df) +#'} +setMethod("schema", + signature(x = "DataFrame"), + function(x) { + structType(callJMethod(x@sdf, "schema")) + }) + +#' Explain +#' +#' Print the logical and physical Catalyst plans to the console for debugging. +#' +#' @param x A SparkSQL DataFrame +#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @rdname explain +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' explain(df, TRUE) +#'} +setMethod("explain", + signature(x = "DataFrame"), + function(x, extended = FALSE) { + queryExec <- callJMethod(x@sdf, "queryExecution") + if (extended) { + cat(callJMethod(queryExec, "toString")) + } else { + execPlan <- callJMethod(queryExec, "executedPlan") + cat(callJMethod(execPlan, "toString")) + } + }) + +#' isLocal +#' +#' Returns True if the `collect` and `take` methods can be run locally +#' (without any Spark executors). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname isLocal +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' isLocal(df) +#'} +setMethod("isLocal", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "isLocal") + }) + +#' ShowDF +#' +#' Print the first numRows rows of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @param numRows The number of rows to print. Defaults to 20. +#' +#' @rdname showDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' showDF(df) +#'} +setMethod("showDF", + signature(x = "DataFrame"), + function(x, numRows = 20) { + cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") + }) + +#' show +#' +#' Print the DataFrame column names and types +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname show +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' show(df) +#'} +setMethod("show", "DataFrame", + function(object) { + cols <- lapply(dtypes(object), function(l) { + paste(l, collapse = ":") + }) + s <- paste(cols, collapse = ", ") + cat(paste("DataFrame[", s, "]\n", sep = "")) + }) + +#' DataTypes +#' +#' Return all column names and their data types as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname dtypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dtypes(df) +#'} +setMethod("dtypes", + signature(x = "DataFrame"), + function(x) { + lapply(schema(x)$fields(), function(f) { + c(f$name(), f$dataType.simpleString()) + }) + }) + +#' Column names +#' +#' Return all column names as a list +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname columns +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' columns(df) +#'} +setMethod("columns", + signature(x = "DataFrame"), + function(x) { + sapply(schema(x)$fields(), function(f) { + f$name() + }) + }) + +#' @rdname columns +#' @export +setMethod("names", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' Register Temporary Table +#' +#' Registers a DataFrame as a Temporary Table in the SQLContext +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' +#' @rdname registerTempTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "json_df") +#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#'} +setMethod("registerTempTable", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName) { + callJMethod(x@sdf, "registerTempTable", tableName) + }) + +#' insertInto +#' +#' Insert the contents of a DataFrame into a table registered in the current SQL Context. +#' +#' @param x A SparkSQL DataFrame +#' @param tableName A character vector containing the name of the table +#' @param overwrite A logical argument indicating whether or not to overwrite +#' the existing rows in the table. +#' +#' @rdname insertInto +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' registerTempTable(df, "table1") +#' insertInto(df2, "table1", overwrite = TRUE) +#'} +setMethod("insertInto", + signature(x = "DataFrame", tableName = "character"), + function(x, tableName, overwrite = FALSE) { + callJMethod(x@sdf, "insertInto", tableName, overwrite) + }) + +#' Cache +#' +#' Persist with the default storage level (MEMORY_ONLY). +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname cache-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' cache(df) +#'} +setMethod("cache", + signature(x = "DataFrame"), + function(x) { + cached <- callJMethod(x@sdf, "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist +#' +#' Persist this DataFrame with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The DataFrame to persist +#' @rdname persist +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#'} +setMethod("persist", + signature(x = "DataFrame", newLevel = "character"), + function(x, newLevel) { + callJMethod(x@sdf, "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist +#' +#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The DataFrame to unpersist +#' @param blocking Whether to block until all blocks are deleted +#' @rdname unpersist-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' persist(df, "MEMORY_AND_DISK") +#' unpersist(df) +#'} +setMethod("unpersist", + signature(x = "DataFrame"), + function(x, blocking = TRUE) { + callJMethod(x@sdf, "unpersist", blocking) + x@env$isCached <- FALSE + x + }) + +#' Repartition +#' +#' Return a new DataFrame that has exactly numPartitions partitions. +#' +#' @param x A SparkSQL DataFrame +#' @param numPartitions The number of partitions to use. +#' @rdname repartition +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- repartition(df, 2L) +#'} +setMethod("repartition", + signature(x = "DataFrame", numPartitions = "numeric"), + function(x, numPartitions) { + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + dataFrame(sdf) + }) + +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @rdname tojson +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newRDD <- toJSON(df) +#'} +setMethod("toJSON", + signature(x = "DataFrame"), + function(x) { + rdd <- callJMethod(x@sdf, "toJSON") + jrdd <- callJMethod(rdd, "toJavaRDD") + RDD(jrdd, serializedMode = "string") + }) + +#' saveAsParquetFile +#' +#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out +#' with this method can be read back in as a DataFrame using parquetFile(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' @rdname saveAsParquetFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#'} +setMethod("saveAsParquetFile", + signature(x = "DataFrame", path = "character"), + function(x, path) { + invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + }) + +#' Distinct +#' +#' Return a new DataFrame containing the distinct rows in this DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @rdname distinct +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' distinctDF <- distinct(df) +#'} +setMethod("distinct", + signature(x = "DataFrame"), + function(x) { + sdf <- callJMethod(x@sdf, "distinct") + dataFrame(sdf) + }) + +#' SampleDF +#' +#' Return a sampled subset of this DataFrame using a random seed. +#' +#' @param x A SparkSQL DataFrame +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @rdname sampleDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collect(sampleDF(df, FALSE, 0.5)) +#' collect(sampleDF(df, TRUE, 0.5)) +#'} +setMethod("sampleDF", + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + signature(x = "DataFrame", withReplacement = "logical", + fraction = "numeric"), + function(x, withReplacement, fraction) { + if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + dataFrame(sdf) + }) + +#' Count +#' +#' Returns the number of rows in a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname count +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' count(df) +#' } +setMethod("count", + signature(x = "DataFrame"), + function(x) { + callJMethod(x@sdf, "count") + }) + +#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame. +#' +#' @param x A SparkSQL DataFrame +#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' should be converted to factors. FALSE by default. + +#' @rdname collect-methods +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' collected <- collect(df) +#' firstName <- collected[[1]]$name +#' } +setMethod("collect", + signature(x = "DataFrame"), + function(x, stringsAsFactors = FALSE) { + # listCols is a list of raw vectors, one per column + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + cols <- lapply(listCols, function(col) { + objRaw <- rawConnection(col) + numRows <- readInt(objRaw) + col <- readCol(objRaw, numRows) + close(objRaw) + col + }) + names(cols) <- columns(x) + do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) + }) + +#' Limit +#' +#' Limit the resulting DataFrame to the number of rows specified. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return +#' @return A new DataFrame containing the number of rows specified. +#' +#' @rdname limit +#' @export +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' limitedDF <- limit(df, 10) +#' } +setMethod("limit", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + res <- callJMethod(x@sdf, "limit", as.integer(num)) + dataFrame(res) + }) + +# Take the first NUM rows of a DataFrame and return a the results as a data.frame + +#' @rdname take +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' take(df, 2) +#' } +setMethod("take", + signature(x = "DataFrame", num = "numeric"), + function(x, num) { + limited <- limit(x, num) + collect(limited) + }) + +#' Head +#' +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame +#' convention in R. +#' +#' @param x A SparkSQL DataFrame +#' @param num The number of rows to return. Default is 6. +#' @return A data.frame +#' +#' @rdname head +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' head(df) +#' } +setMethod("head", + signature(x = "DataFrame"), + function(x, num = 6L) { + # Default num is 6L in keeping with R's data.frame convention + take(x, num) + }) + +#' Return the first row of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' first(df) +#' } +setMethod("first", + signature(x = "DataFrame"), + function(x) { + take(x, 1) + }) + +#' toRDD() +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' rdd <- toRDD(df) +#' } +setMethod("toRDD", + signature(x = "DataFrame"), + function(x) { + jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf) + colNames <- callJMethod(x@sdf, "columns") + rdd <- RDD(jrdd, serializedMode = "row") + lapply(rdd, function(row) { + names(row) <- colNames + row + }) + }) + +#' GroupBy +#' +#' Groups the DataFrame using the specified columns, so we can run aggregation on them. +#' +#' @param x a DataFrame +#' @return a GroupedData +#' @seealso GroupedData +#' @rdname DataFrame +#' @export +#' @examples +#' \dontrun{ +#' # Compute the average for all numeric columns grouped by department. +#' avg(groupBy(df, "department")) +#' +#' # Compute the max age and average salary, grouped by department and gender. +#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") +#' } +setMethod("groupBy", + signature(x = "DataFrame"), + function(x, ...) { + cols <- list(...) + if (length(cols) >= 1 && class(cols[[1]]) == "character") { + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + } else { + jcol <- lapply(cols, function(c) { c@jc }) + sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + } + groupedData(sgd) + }) + +#' Agg +#' +#' Compute aggregates by specifying a list of columns +#' +#' @rdname DataFrame +#' @export +setMethod("agg", + signature(x = "DataFrame"), + function(x, ...) { + agg(groupBy(x), ...) + }) + + +############################## RDD Map Functions ################################## +# All of the following functions mirror the existing RDD map functions, # +# but allow for use with DataFrames by first converting to an RRDD before calling # +# the requested map function. # +################################################################################### + +#' @rdname lapply +setMethod("lapply", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapply(rdd, FUN) + }) + +#' @rdname lapply +setMethod("map", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' @rdname flatMap +setMethod("flatMap", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + flatMap(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("lapplyPartition", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + rdd <- toRDD(X) + lapplyPartition(rdd, FUN) + }) + +#' @rdname lapplyPartition +setMethod("mapPartitions", + signature(X = "DataFrame", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' @rdname foreach +setMethod("foreach", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreach(rdd, func) + }) + +#' @rdname foreach +setMethod("foreachPartition", + signature(x = "DataFrame", func = "function"), + function(x, func) { + rdd <- toRDD(x) + foreachPartition(rdd, func) + }) + + +############################## SELECT ################################## + +getColumn <- function(x, c) { + column(callJMethod(x@sdf, "col", c)) +} + +#' @rdname select +setMethod("$", signature(x = "DataFrame"), + function(x, name) { + getColumn(x, name) + }) + +setMethod("$<-", signature(x = "DataFrame"), + function(x, name, value) { + stopifnot(class(value) == "Column") + cols <- columns(x) + if (name %in% cols) { + cols <- lapply(cols, function(c) { + if (c == name) { + alias(value, name) + } else { + col(c) + } + }) + nx <- select(x, cols) + } else { + nx <- withColumn(x, name, value) + } + x@sdf <- nx@sdf + x + }) + +#' @rdname select +setMethod("[[", signature(x = "DataFrame"), + function(x, i) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + getColumn(x, i) + }) + +#' @rdname select +setMethod("[", signature(x = "DataFrame", i = "missing"), + function(x, i, j, ...) { + if (is.numeric(j)) { + cols <- columns(x) + j <- cols[j] + } + if (length(j) > 1) { + j <- as.list(j) + } + select(x, j) + }) + +#' Select +#' +#' Selects a set of columns with names or Column expressions. +#' @param x A DataFrame +#' @param col A list of columns or single Column or name +#' @return A new DataFrame with selected columns +#' @export +#' @rdname select +#' @examples +#' \dontrun{ +#' select(df, "*") +#' select(df, "col1", "col2") +#' select(df, df$name, df$age + 1) +#' select(df, c("col1", "col2")) +#' select(df, list(df$name, df$age + 1)) +#' # Columns can also be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' # Similar to R data frames columns can also be selected using `$` +#' df$age +#' } +setMethod("select", signature(x = "DataFrame", col = "character"), + function(x, col, ...) { + sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", signature(x = "DataFrame", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + dataFrame(sdf) + }) + +#' @rdname select +#' @export +setMethod("select", + signature(x = "DataFrame", col = "list"), + function(x, col) { + cols <- lapply(col, function(c) { + if (class(c)== "Column") { + c@jc + } else { + col(c)@jc + } + }) + sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + dataFrame(sdf) + }) + +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + +#' WithColumn +#' +#' Return a new DataFrame with the specified column added. +#' +#' @param x A DataFrame +#' @param colName A string containing the name of the new column. +#' @param col A Column expression. +#' @return A DataFrame with the new column added. +#' @rdname withColumn +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' } +setMethod("withColumn", + signature(x = "DataFrame", colName = "character", col = "Column"), + function(x, colName, col) { + select(x, x$"*", alias(col, colName)) + }) + +#' WithColumnRenamed +#' +#' Rename an existing column in a DataFrame. +#' +#' @param x A DataFrame +#' @param existingCol The name of the column you want to change. +#' @param newCol The new column name. +#' @return A DataFrame with the column name changed. +#' @rdname withColumnRenamed +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' newDF <- withColumnRenamed(df, "col1", "newCol1") +#' } +setMethod("withColumnRenamed", + signature(x = "DataFrame", existingCol = "character", newCol = "character"), + function(x, existingCol, newCol) { + cols <- lapply(columns(x), function(c) { + if (c == existingCol) { + alias(col(c), newCol) + } else { + col(c) + } + }) + select(x, cols) + }) + +setClassUnion("characterOrColumn", c("character", "Column")) + +#' SortDF +#' +#' Sort a DataFrame by the specified column(s). +#' +#' @param x A DataFrame to be sorted. +#' @param col Either a Column object or character vector indicating the field to sort on +#' @param ... Additional sorting fields +#' @return A DataFrame where all elements are sorted. +#' @rdname sortDF +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' sortDF(df, df$col1) +#' sortDF(df, "col1") +#' sortDF(df, asc(df$col1), desc(abs(df$col2))) +#' } +setMethod("sortDF", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col, ...) { + if (class(col) == "character") { + sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + } else if (class(col) == "Column") { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + } + dataFrame(sdf) + }) + +#' @rdname sortDF +#' @export +setMethod("orderBy", + signature(x = "DataFrame", col = "characterOrColumn"), + function(x, col) { + sortDF(x, col) + }) + +#' Filter +#' +#' Filter the rows of a DataFrame according to a given condition. +#' +#' @param x A DataFrame to be sorted. +#' @param condition The condition to sort on. This may either be a Column expression +#' or a string containing a SQL statement +#' @return A DataFrame containing only the rows that meet the condition. +#' @rdname filter +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' filter(df, "col1 > 0") +#' filter(df, df$col2 != "abcdefg") +#' } +setMethod("filter", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + if (class(condition) == "Column") { + condition <- condition@jc + } + sdf <- callJMethod(x@sdf, "filter", condition) + dataFrame(sdf) + }) + +#' @rdname filter +#' @export +setMethod("where", + signature(x = "DataFrame", condition = "characterOrColumn"), + function(x, condition) { + filter(x, condition) + }) + +#' Join +#' +#' Join two DataFrames based on the given join expression. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' @param joinType The type of join to perform. The following join types are available: +#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' @return A DataFrame containing the result of the join operation. +#' @rdname join +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' join(df1, df2) # Performs a Cartesian +#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression +#' join(df1, df2, df1$col1 == df2$col2, "right_outer") +#' } +setMethod("join", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y, joinExpr = NULL, joinType = NULL) { + if (is.null(joinExpr)) { + sdf <- callJMethod(x@sdf, "join", y@sdf) + } else { + if (class(joinExpr) != "Column") stop("joinExpr must be a Column") + if (is.null(joinType)) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) + } else { + if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) + } else { + stop("joinType must be one of the following types: ", + "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + } + } + } + dataFrame(sdf) + }) + +#' UnionAll +#' +#' Return a new DataFrame containing the union of rows in this DataFrame +#' and another DataFrame. This is equivalent to `UNION ALL` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the union. +#' @rdname unionAll +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' unioned <- unionAll(df, df2) +#' } +setMethod("unionAll", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + dataFrame(unioned) + }) + +#' Intersect +#' +#' Return a new DataFrame containing rows only in both this DataFrame +#' and another DataFrame. This is equivalent to `INTERSECT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the intersect. +#' @rdname intersect +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' intersectDF <- intersect(df, df2) +#' } +setMethod("intersect", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + intersected <- callJMethod(x@sdf, "intersect", y@sdf) + dataFrame(intersected) + }) + +#' Subtract +#' +#' Return a new DataFrame containing rows in this DataFrame +#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. +#' +#' @param x A Spark DataFrame +#' @param y A Spark DataFrame +#' @return A DataFrame containing the result of the subtract operation. +#' @rdname subtract +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlCtx, path) +#' df2 <- jsonFile(sqlCtx, path2) +#' subtractDF <- subtract(df, df2) +#' } +setMethod("subtract", + signature(x = "DataFrame", y = "DataFrame"), + function(x, y) { + subtracted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(subtracted) + }) + +#' Save the contents of the DataFrame to a data source +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param path A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveDF", + signature(df = "DataFrame", path = 'character', source = 'character', + mode = 'character'), + function(df, path = NULL, source = NULL, mode = "append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] = path + } + callJMethod(df@sdf, "save", source, jmode, options) + }) + + +#' saveAsTable +#' +#' Save the contents of the DataFrame to a data source as a table +#' +#' The data source is specified by the `source` and a set of options (...). +#' If `source` is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: +#' append: Contents of this DataFrame are expected to be appended to existing data. +#' overwrite: Existing data is expected to be overwritten by the contents of +# this DataFrame. +#' error: An exception is expected to be thrown. +#' ignore: The save operation is expected to not save the contents of the DataFrame +# and to not change the existing data. +#' +#' @param df A SparkSQL DataFrame +#' @param tableName A name for the table +#' @param source A name for external data source +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' +#' @rdname saveAsTable +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' saveAsTable(df, "myfile") +#' } +setMethod("saveAsTable", + signature(df = "DataFrame", tableName = 'character', source = 'character', + mode = 'character'), + function(df, tableName, source = NULL, mode="append", ...){ + if (is.null(source)) { + sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + options <- varargsToEnv(...) + callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + }) + diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R new file mode 100644 index 0000000000..604ad03c40 --- /dev/null +++ b/R/pkg/R/RDD.R @@ -0,0 +1,1539 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# RDD in R implemented in S4 OO system. + +setOldClass("jobj") + +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @export +setClass("RDD", + slots = list(env = "environment", + jrdd = "jobj")) + +setClass("PipelinedRDD", + slots = list(prev = "RDD", + func = "function", + prev_jrdd = "jobj"), + contains = "RDD") + +setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, + isCached, isCheckpointed) { + # Check that RDD constructor is using the correct version of serializedMode + stopifnot(class(serializedMode) == "character") + stopifnot(serializedMode %in% c("byte", "string", "row")) + # RDD has three serialization types: + # byte: The RDD stores data serialized in R. + # string: The RDD stores data as strings. + # row: The RDD stores the serialized rows of a DataFrame. + + # We use an environment to store mutable states inside an RDD object. + # Note that R's call-by-value semantics makes modifying slots inside an + # object (passed as an argument into a function, such as cache()) difficult: + # i.e. one needs to make a copy of the RDD object and sets the new slot value + # there. + + # The slots are inheritable from superclass. Here, both `env' and `jrdd' are + # inherited from RDD, but only the former is used. + .Object@env <- new.env() + .Object@env$isCached <- isCached + .Object@env$isCheckpointed <- isCheckpointed + .Object@env$serializedMode <- serializedMode + + .Object@jrdd <- jrdd + .Object +}) + +setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { + .Object@env <- new.env() + .Object@env$isCached <- FALSE + .Object@env$isCheckpointed <- FALSE + .Object@env$jrdd_val <- jrdd_val + if (!is.null(jrdd_val)) { + # This tracks the serialization mode for jrdd_val + .Object@env$serializedMode <- prev@env$serializedMode + } + + .Object@prev <- prev + + isPipelinable <- function(rdd) { + e <- rdd@env + !(e$isCached || e$isCheckpointed) + } + + if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { + # This transformation is the first in its stage: + .Object@func <- func + .Object@prev_jrdd <- getJRDD(prev) + .Object@env$prev_serializedMode <- prev@env$serializedMode + # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD + # prev_serializedMode is used during the delayed computation of JRDD in getJRDD + } else { + pipelinedFunc <- function(split, iterator) { + func(split, prev@func(split, iterator)) + } + .Object@func <- pipelinedFunc + .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline + # Get the serialization mode of the parent RDD + .Object@env$prev_serializedMode <- prev@env$prev_serializedMode + } + + .Object +}) + +#' @rdname RDD +#' @export +#' +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed +RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, + isCheckpointed = FALSE) { + new("RDD", jrdd, serializedMode, isCached, isCheckpointed) +} + +PipelinedRDD <- function(prev, func) { + new("PipelinedRDD", prev, func, NULL) +} + +# Return the serialization mode for an RDD. +setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") }) +# For normal RDDs we can directly read the serializedMode +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +# For pipelined RDDs if jrdd_val is set then serializedMode should exist +# if not we return the defaultSerialization mode of "byte" as we don't know the serialization +# mode at this point in time. +setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), + function(rdd) { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$serializedMode) + } else { + return("byte") + } + }) + +# The jrdd accessor function. +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "PipelinedRDD"), + function(rdd, serializedMode = "byte") { + if (!is.null(rdd@env$jrdd_val)) { + return(rdd@env$jrdd_val) + } + + computeFunc <- function(split, part) { + rdd@func(split, part) + } + + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + serializedFuncArr <- serialize(computeFunc, connection = NULL) + + prev_jrdd <- rdd@prev_jrdd + + if (serializedMode == "string") { + rddRef <- newJObject("org.apache.spark.api.r.StringRRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } else { + rddRef <- newJObject("org.apache.spark.api.r.RRDD", + callJMethod(prev_jrdd, "rdd"), + serializedFuncArr, + rdd@env$prev_serializedMode, + serializedMode, + packageNamesArr, + as.character(.sparkREnv[["libname"]]), + broadcastArr, + callJMethod(prev_jrdd, "classTag")) + } + # Save the serialization flag after we create a RRDD + rdd@env$serializedMode <- serializedMode + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val + }) + +setValidity("RDD", + function(object) { + jrdd <- getJRDD(object) + cls <- callJMethod(jrdd, "getClass") + className <- callJMethod(cls, "getName") + if (grep("spark.api.java.*RDD*", className) == 1) { + TRUE + } else { + paste("Invalid RDD class ", className) + } + }) + + +############ Actions and Transformations ############ + +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +setMethod("cache", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "cache") + x@env$isCached <- TRUE + x + }) + +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +setMethod("persist", + signature(x = "RDD", newLevel = "character"), + function(x, newLevel) { + callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) + x@env$isCached <- TRUE + x + }) + +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +setMethod("unpersist", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "unpersist") + x@env$isCached <- FALSE + x + }) + +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoints") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +setMethod("checkpoint", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + callJMethod(jrdd, "checkpoint") + x@env$isCheckpointed <- TRUE + x + }) + +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' numPartitions(rdd) # 2L +#'} +#' @rdname numPartitions +#' @aliases numPartitions,RDD-method +setMethod("numPartitions", + signature(x = "RDD"), + function(x) { + jrdd <- getJRDD(x) + partitions <- callJMethod(jrdd, "splits") + callJMethod(partitions, "size") + }) + +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +setMethod("collect", + signature(x = "RDD"), + function(x, flatten = TRUE) { + # Assumes a pairwise RDD is backed by a JavaPairRDD. + collected <- callJMethod(getJRDD(x), "collect") + convertJListToRList(collected, flatten, + serializedMode = getSerializedMode(x)) + }) + + +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +setMethod("collectPartition", + signature(x = "RDD", partitionId = "integer"), + function(x, partitionId) { + jPartitionsList <- callJMethod(getJRDD(x), + "collectPartitions", + as.list(as.integer(partitionId))) + + jList <- jPartitionsList[[1]] + convertJListToRList(jList, flatten = TRUE, + serializedMode = getSerializedMode(x)) + }) + +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +setMethod("collectAsMap", + signature(x = "RDD"), + function(x) { + pairList <- collect(x) + map <- new.env() + lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) + as.list(map) + }) + +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +setMethod("count", + signature(x = "RDD"), + function(x) { + countPartition <- function(part) { + as.integer(length(part)) + } + valsRDD <- lapplyPartition(x, countPartition) + vals <- collect(valsRDD) + sum(as.integer(vals)) + }) + +#' Return the number of elements in the RDD +#' @export +#' @rdname count +setMethod("length", + signature(x = "RDD"), + function(x) { + count(x) + }) + +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +setMethod("countByValue", + signature(x = "RDD"), + function(x) { + ones <- lapply(x, function(item) { list(item, 1L) }) + collect(reduceByKey(ones, `+`, numPartitions(x))) + }) + +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} +setMethod("lapply", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(split, iterator) { + lapply(iterator, FUN) + } + lapplyPartitionsWithIndex(X, func) + }) + +#' @rdname lapply +#' @aliases map,RDD,function-method +setMethod("map", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapply(X, FUN) + }) + +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +setMethod("flatMap", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + partitionFunc <- function(part) { + unlist( + lapply(part, FUN), + recursive = F + ) + } + lapplyPartition(X, partitionFunc) + }) + +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +setMethod("lapplyPartition", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) + }) + +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +setMethod("mapPartitions", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartition(X, FUN) + }) + +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) { +#' split * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +setMethod("lapplyPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + FUN <- cleanClosure(FUN) + closureCapturingFunc <- function(split, part) { + FUN(split, part) + } + PipelinedRDD(X, closureCapturingFunc) + }) + +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +setMethod("mapPartitionsWithIndex", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + lapplyPartitionsWithIndex(X, FUN) + }) + +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +setMethod("filterRDD", + signature(x = "RDD", f = "function"), + function(x, f) { + filter.func <- function(part) { + Filter(f, part) + } + lapplyPartition(x, filter.func) + }) + +#' @rdname filterRDD +#' @aliases Filter +setMethod("Filter", + signature(f = "function", x = "RDD"), + function(f, x) { + filterRDD(x, f) + }) + +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +setMethod("reduce", + signature(x = "RDD", func = "ANY"), + function(x, func) { + + reducePartition <- function(part) { + Reduce(func, part) + } + + partitionList <- collect(lapplyPartition(x, reducePartition), + flatten = FALSE) + Reduce(func, partitionList) + }) + +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +setMethod("maximum", + signature(x = "RDD"), + function(x) { + reduce(x, max) + }) + +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +setMethod("minimum", + signature(x = "RDD"), + function(x) { + reduce(x, min) + }) + +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +setMethod("sumRDD", + signature(x = "RDD"), + function(x) { + reduce(x, "+") + }) + +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +setMethod("foreach", + signature(x = "RDD", func = "function"), + function(x, func) { + partition.func <- function(x) { + lapply(x, func) + NULL + } + invisible(collect(mapPartitions(x, partition.func))) + }) + +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +setMethod("foreachPartition", + signature(x = "RDD", func = "function"), + function(x, func) { + invisible(collect(mapPartitions(x, func))) + }) + +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +setMethod("take", + signature(x = "RDD", num = "numeric"), + function(x, num) { + resList <- list() + index <- -1 + jrdd <- getJRDD(x) + numPartitions <- numPartitions(x) + + # TODO(shivaram): Collect more than one partition based on size + # estimates similar to the scala version of `take`. + while (TRUE) { + index <- index + 1 + + if (length(resList) >= num || index >= numPartitions) + break + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + size <- num - length(resList) + # elems is capped to have at most `size` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = size, + serializedMode = getSerializedMode(x)) + # TODO: Check if this append is O(n^2)? + resList <- append(resList, elems) + } + resList + }) + +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +setMethod("first", + signature(x = "RDD"), + function(x) { + take(x, 1)[[1]] + }) + +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +setMethod("distinct", + signature(x = "RDD"), + function(x, numPartitions = SparkR::numPartitions(x)) { + identical.mapped <- lapply(x, function(x) { list(x, NULL) }) + reduced <- reduceByKey(identical.mapped, + function(x, y) { x }, + numPartitions) + resRDD <- lapply(reduced, function(x) { x[[1]] }) + resRDD + }) + +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +setMethod("sampleRDD", + signature(x = "RDD", withReplacement = "logical", + fraction = "numeric", seed = "integer"), + function(x, withReplacement, fraction, seed) { + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(seed) + res <- vector("list", length(part)) + len <- 0 + + # Discards some random values to ensure each partition has a + # different random seed. + runif(split) + + for (elem in part) { + if (withReplacement) { + count <- rpois(1, fraction) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < fraction) { + len <- len + 1 + res[[len]] <- elem + } + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. + if (len > 0) + res[1:len] + else + list() + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) + +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", + num = "integer", seed = "integer"), + function(x, withReplacement, num, seed) { + # This function is ported from RDD.scala. + fraction <- 0.0 + total <- 0 + multiplier <- 3.0 + initialCount <- count(x) + maxSelected <- 0 + MAXINT <- .Machine$integer.max + + if (num < 0) + stop(paste("Negative number of elements requested")) + + if (initialCount > MAXINT - 1) { + maxSelected <- MAXINT - 1 + } else { + maxSelected <- initialCount + } + + if (num > initialCount && !withReplacement) { + total <- maxSelected + fraction <- multiplier * (maxSelected + 1) / initialCount + } else { + total <- num + fraction <- multiplier * (num + 1) / initialCount + } + + set.seed(seed) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + # If the first sample didn't turn out large enough, keep trying to + # take samples; this shouldn't happen often because we use a big + # multiplier for thei initial size + while (length(samples) < total) + samples <- collect(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(runif(1, + -MAXINT, + MAXINT))))) + + # TODO(zongheng): investigate if this call is an in-place shuffle? + sample(samples)[1:total] + }) + +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +setMethod("keyBy", + signature(x = "RDD", func = "function"), + function(x, func) { + apply.func <- function(x) { + list(func(x), x) + } + lapply(x, apply.func) + }) + +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' numPartitions(rdd) # 4 +#' numPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +setMethod("repartition", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions) { + coalesce(x, numToInt(numPartitions), TRUE) + }) + +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' numPartitions(rdd) # 3 +#' numPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +setMethod("coalesce", + signature(x = "RDD", numPartitions = "numeric"), + function(x, numPartitions, shuffle = FALSE) { + numPartitions <- numToInt(numPartitions) + if (shuffle || numPartitions > SparkR::numPartitions(x)) { + func <- function(s, part) { + set.seed(s) # split as seed + start <- as.integer(sample(numPartitions, 1) - 1) + lapply(seq_along(part), + function(i) { + pos <- (start + i) %% numPartitions + list(pos, part[[i]]) + }) + } + shuffled <- lapplyPartitionsWithIndex(x, func) + repartitioned <- partitionBy(shuffled, numPartitions) + values(repartitioned) + } else { + jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) + RDD(jrdd) + } + }) + +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +setMethod("saveAsObjectFile", + signature(x = "RDD", path = "character"), + function(x, path) { + # If serializedMode == "string" we need to serialize the data before saving it since + # objectFile() assumes serializedMode == "byte". + if (getSerializedMode(x) != "byte") { + x <- serializeToBytes(x) + } + # Return nothing + invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) + }) + +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the splits of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +setMethod("saveAsTextFile", + signature(x = "RDD", path = "character"), + function(x, path) { + func <- function(str) { + toString(str) + } + stringRdd <- lapply(x, func) + # Return nothing + invisible( + callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) + }) + +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +setMethod("sortBy", + signature(x = "RDD", func = "function"), + function(x, func, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + values(sortByKey(keyBy(x, func), ascending, numPartitions)) + }) + +# Helper function to get first N elements from an RDD in the specified order. +# Param: +# x An RDD. +# num Number of elements to return. +# ascending A flag to indicate whether the sorting is ascending or descending. +# Return: +# A list of the first N elements from the RDD in the specified order. +# +takeOrderedElem <- function(x, num, ascending = TRUE) { + if (num <= 0L) { + return(list()) + } + + partitionFunc <- function(part) { + if (num < length(part)) { + # R limitation: order works only on primitive types! + ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) + list(part[ord[1:num]]) + } else { + list(part) + } + } + + reduceFunc <- function(elems, part) { + newElems <- append(elems, part) + # R limitation: order works only on primitive types! + ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) + newElems[ord[1:num]] + } + + newRdd <- mapPartitions(x, partitionFunc) + reduce(newRdd, reduceFunc) +} + +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +setMethod("takeOrdered", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num) + }) + +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @rdname top +#' @aliases top,RDD,RDD-method +setMethod("top", + signature(x = "RDD", num = "integer"), + function(x, num) { + takeOrderedElem(x, num, FALSE) + }) + +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @rdname fold +#' @aliases fold,RDD,RDD-method +setMethod("fold", + signature(x = "RDD", zeroValue = "ANY", op = "ANY"), + function(x, zeroValue, op) { + aggregateRDD(x, zeroValue, op, op) + }) + +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @rdname aggregateRDD +#' @aliases aggregateRDD,RDD,RDD-method +setMethod("aggregateRDD", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), + function(x, zeroValue, seqOp, combOp) { + partitionFunc <- function(part) { + Reduce(seqOp, part, zeroValue) + } + + partitionList <- collect(lapplyPartition(x, partitionFunc), + flatten = FALSE) + Reduce(combOp, partitionList, zeroValue) + }) + +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @rdname pipeRDD +#' @aliases pipeRDD,RDD,character-method +setMethod("pipeRDD", + signature(x = "RDD", command = "character"), + function(x, command, env = list()) { + func <- function(part) { + trim.trailing.func <- function(x) { + sub("[\r\n]*$", "", toString(x)) + } + input <- unlist(lapply(part, trim.trailing.func)) + res <- system2(command, stdout = TRUE, input = input, env = env) + lapply(res, trim.trailing.func) + } + lapplyPartition(x, func) + }) + +# TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @rdname name +#' @aliases name,RDD +setMethod("name", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "name") + }) + +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @rdname setName +#' @aliases setName,RDD +setMethod("setName", + signature(x = "RDD", name = "character"), + function(x, name) { + callJMethod(getJRDD(x), "setName", name) + x + }) + +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +setMethod("zipWithUniqueId", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + + partitionFunc <- function(split, part) { + mapply( + function(item, index) { + list(item, (index - 1) * n + split) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +setMethod("zipWithIndex", + signature(x = "RDD"), + function(x) { + n <- numPartitions(x) + if (n > 1) { + nums <- collect(lapplyPartition(x, + function(part) { + list(length(part)) + })) + startIndices <- Reduce("+", nums, accumulate = TRUE) + } + + partitionFunc <- function(split, part) { + if (split == 0) { + startIndex <- 0 + } else { + startIndex <- startIndices[[split]] + } + + mapply( + function(item, index) { + list(item, index - 1 + startIndex) + }, + part, + seq_along(part), + SIMPLIFY = FALSE) + } + + lapplyPartitionsWithIndex(x, partitionFunc) + }) + +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +setMethod("glom", + signature(x = "RDD"), + function(x) { + partitionFunc <- function(part) { + list(part) + } + + lapplyPartition(x, partitionFunc) + }) + +############ Binary Functions ############# + +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +setMethod("unionRDD", + signature(x = "RDD", y = "RDD"), + function(x, y) { + if (getSerializedMode(x) == getSerializedMode(y)) { + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, getSerializedMode(x)) + } else { + # One of the RDDs is not serialized, we need to serialize it first. + if (getSerializedMode(x) != "byte") x <- serializeToBytes(x) + if (getSerializedMode(y) != "byte") y <- serializeToBytes(y) + jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y)) + union.rdd <- RDD(jrdd, "byte") + } + union.rdd + }) + +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +setMethod("zipRDD", + signature(x = "RDD", other = "RDD"), + function(x, other) { + n1 <- numPartitions(x) + n2 <- numPartitions(other) + if (n1 != n2) { + stop("Can only zip RDDs which have the same number of partitions.") + } + + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # check if corresponding partitions of both RDDs have the same number of elements. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + part[[length(part) + 1]] <- length(part) + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + + zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) + # The zippedRDD's elements are of scala Tuple2 type. The serialized + # flag Here is used for the elements inside the tuples. + serializerMode <- getSerializedMode(x) + zippedRDD <- RDD(zippedJRDD, serializerMode) + + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # check if corresponding partitions of both RDDs have the same number of elements. + if (lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + keys <- list() + values <- list() + } + } else { + # Keys, values must have same length here, because this has + # been validated inside the JavaRDD.zip() function. + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { + list(k, v) + }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(zippedRDD, partitionFunc) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R new file mode 100644 index 0000000000..930ada22f4 --- /dev/null +++ b/R/pkg/R/SQLContext.R @@ -0,0 +1,520 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# SQLcontext.R: SQLContext-driven functions + +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + list(type = "struct", fields = fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' dump the schema into JSON string +tojson <- function(x) { + if (is.list(x)) { + names <- names(x) + if (!is.null(names)) { + items <- lapply(names, function(n) { + safe_n <- gsub('"', '\\"', n) + paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') + }) + d <- paste(items, collapse = ', ') + paste('{', d, '}', sep = '') + } else { + l <- paste(lapply(x, tojson), collapse = ', ') + paste('[', l, ']', sep = '') + } + } else if (is.character(x)) { + paste('"', x, '"', sep = '') + } else if (is.logical(x)) { + if (x) "true" else "false" + } else { + stop(paste("unexpected type:", class(x))) + } +} + +#' Create a DataFrame from an RDD +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param sqlCtx A SQLContext +#' @param data An RDD or list or data.frame +#' @param schema a list of column names or named list (StructType), optional +#' @return an DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- createDataFrame(sqlCtx, rdd) +#' } + +# TODO(davies): support sampling and infer type from NA +createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { + if (is.data.frame(data)) { + # get the names of columns, they will be put into RDD + schema <- names(data) + n <- nrow(data) + m <- ncol(data) + # get rid of factor type + dropFactor <- function(x) { + if (is.factor(x)) { + as.character(x) + } else { + x + } + } + data <- lapply(1:n, function(i) { + lapply(1:m, function(j) { dropFactor(data[i,j]) }) + }) + } + if (is.list(data)) { + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + rdd <- parallelize(sc, data) + } else if (inherits(data, "RDD")) { + rdd <- data + } else { + stop(paste("unexpected type:", class(data))) + } + + if (is.null(schema) || is.null(names(schema))) { + row <- first(rdd) + names <- if (is.null(schema)) { + names(row) + } else { + as.list(schema) + } + if (is.null(names)) { + names <- lapply(1:length(row), function(x) { + paste("_", as.character(x), sep = "") + }) + } + + # SPAKR-SQL does not support '.' in column name, so replace it with '_' + # TODO(davies): remove this once SPARK-2775 is fixed + names <- lapply(names, function(n) { + nn <- gsub("[.]", "_", n) + if (nn != n) { + warning(paste("Use", nn, "instead of", n, " as column name")) + } + nn + }) + + types <- lapply(row, infer_type) + fields <- lapply(1:length(row), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + schema <- list(type = "struct", fields = fields) + } + + stopifnot(class(schema) == "list") + stopifnot(schema$type == "struct") + stopifnot(class(schema$fields) == "list") + schemaString <- tojson(schema) + + jrdd <- getJRDD(lapply(rdd, function(x) x), "row") + srdd <- callJMethod(jrdd, "rdd") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", + srdd, schemaString, sqlCtx) + dataFrame(sdf) +} + +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#' } + +setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) + +setMethod("toDF", signature(x = "RDD"), + function(x, ...) { + sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + get(".sparkRHivesc", envir = .sparkREnv) + } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + get(".sparkRSQLsc", envir = .sparkREnv) + } else { + stop("no SQL context available") + } + createDataFrame(sqlCtx, x, ...) + }) + +#' Create a DataFrame from a JSON file. +#' +#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' It goes through the entire dataset once to determine the schema. +#' +#' @param sqlCtx SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' } + +jsonFile <- function(sqlCtx, path) { + # Allow the user to have a more flexible definiton of the text file path + path <- normalizePath(path) + # Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + sdf <- callJMethod(sqlCtx, "jsonFile", path) + dataFrame(sdf) +} + + +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlCtx, rdd) +#' } + +# TODO: support schema +jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { + rdd <- serializeToString(rdd) + if (is.null(schema)) { + sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + dataFrame(sdf) + } else { + stop("not implemented") + } +} + + +#' Create a DataFrame from a Parquet file. +#' +#' Loads a Parquet file, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param ... Path(s) of parquet file(s) to read. +#' @return DataFrame +#' @export + +# TODO: Implement saveasParquetFile and write examples for both +parquetFile <- function(sqlCtx, ...) { + # Allow the user to have a more flexible definiton of the text file path + paths <- lapply(list(...), normalizePath) + sdf <- callJMethod(sqlCtx, "parquetFile", paths) + dataFrame(sdf) +} + +#' SQL Query +#' +#' Executes a SQL query using Spark, returning the result as a DataFrame. +#' +#' @param sqlCtx SQLContext to use +#' @param sqlQuery A character vector containing the SQL query +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' } + +sql <- function(sqlCtx, sqlQuery) { + sdf <- callJMethod(sqlCtx, "sql", sqlQuery) + dataFrame(sdf) +} + + +#' Create a DataFrame from a SparkSQL Table +#' +#' Returns the specified Table as a DataFrame. The Table must have already been registered +#' in the SQLContext. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The SparkSQL Table to convert to a DataFrame. +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' new_df <- table(sqlCtx, "table") +#' } + +table <- function(sqlCtx, tableName) { + sdf <- callJMethod(sqlCtx, "table", tableName) + dataFrame(sdf) +} + + +#' Tables +#' +#' Returns a DataFrame containing names of tables in the given database. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tables(sqlCtx, "hive") +#' } + +tables <- function(sqlCtx, databaseName = NULL) { + jdf <- if (is.null(databaseName)) { + callJMethod(sqlCtx, "tables") + } else { + callJMethod(sqlCtx, "tables", databaseName) + } + dataFrame(jdf) +} + + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param sqlCtx SQLContext to use +#' @param databaseName name of the database +#' @return a list of table names +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' tableNames(sqlCtx, "hive") +#' } + +tableNames <- function(sqlCtx, databaseName = NULL) { + if (is.null(databaseName)) { + callJMethod(sqlCtx, "tableNames") + } else { + callJMethod(sqlCtx, "tableNames", databaseName) + } +} + + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being cached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' cacheTable(sqlCtx, "table") +#' } + +cacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "cacheTable", tableName) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the table being uncached +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' registerTempTable(df, "table") +#' uncacheTable(sqlCtx, "table") +#' } + +uncacheTable <- function(sqlCtx, tableName) { + callJMethod(sqlCtx, "uncacheTable", tableName) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @param sqlCtx SQLContext to use +#' @examples +#' \dontrun{ +#' clearCache(sqlCtx) +#' } + +clearCache <- function(sqlCtx) { + callJMethod(sqlCtx, "clearCache") +} + +#' Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName The name of the SparkSQL table to be dropped. +#' @examples +#' \dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- loadDF(sqlCtx, path, "parquet") +#' registerTempTable(df, "table") +#' dropTempTable(sqlCtx, "table") +#' } + +dropTempTable <- function(sqlCtx, tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + callJMethod(sqlCtx, "dropTempTable", tableName) +} + +#' Load an DataFrame +#' +#' Returns the dataset in a data source as a DataFrame +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' } + +loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "load", source, options) + dataFrame(sdf) +} + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns the DataFrame associated with the external table. +#' +#' The data source is specified by the `source` and a set of options(...). +#' If `source` is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param sqlCtx SQLContext to use +#' @param tableName A name of the table +#' @param path The path of files to load +#' @param source the name of external data source +#' @return DataFrame +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' } + +createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { + options <- varargsToEnv(...) + if (!is.null(path)) { + options[['path']] <- path + } + sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + dataFrame(sdf) +} diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R new file mode 100644 index 0000000000..962fba5b3c --- /dev/null +++ b/R/pkg/R/SQLTypes.R @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions for handling SparkSQL DataTypes. + +# Handler for StructType +structType <- function(st) { + obj <- structure(new.env(parent = emptyenv()), class = "structType") + obj$jobj <- st + obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } + obj +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + fieldsList <- lapply(x$fields(), function(i) { i$print() }) + print(fieldsList) +} + +# Handler for StructField +structField <- function(sf) { + obj <- structure(new.env(parent = emptyenv()), class = "structField") + obj$jobj <- sf + obj$name <- function() { callJMethod(sf, "name") } + obj$dataType <- function() { callJMethod(sf, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(sf, "nullable") } + obj$print <- function() { paste("StructField(", + paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), + ")", sep = "") } + obj +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat(x$print()) +} diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R new file mode 100644 index 0000000000..2fb6fae55f --- /dev/null +++ b/R/pkg/R/backend.R @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to call into SparkRBackend. + + +# Returns TRUE if object is an instance of given class +isInstanceOf <- function(jobj, className) { + stopifnot(class(jobj) == "jobj") + cls <- callJStatic("java.lang.Class", "forName", className) + callJMethod(cls, "isInstance", jobj) +} + +# Call a Java method named methodName on the object +# specified by objId. objId should be a "jobj" returned +# from the SparkRBackend. +callJMethod <- function(objId, methodName, ...) { + stopifnot(class(objId) == "jobj") + if (!isValidJobj(objId)) { + stop("Invalid jobj ", objId$id, + ". If SparkR was restarted, Spark operations need to be re-executed.") + } + invokeJava(isStatic = FALSE, objId$id, methodName, ...) +} + +# Call a static method on a specified className +callJStatic <- function(className, methodName, ...) { + invokeJava(isStatic = TRUE, className, methodName, ...) +} + +# Create a new object of the specified class name +newJObject <- function(className, ...) { + invokeJava(isStatic = TRUE, className, methodName = "", ...) +} + +# Remove an object from the SparkR backend. This is done +# automatically when a jobj is garbage collected. +removeJObject <- function(objId) { + invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId) +} + +isRemoveMethod <- function(isStatic, objId, methodName) { + isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm" +} + +# Invoke a Java method on the SparkR backend. Users +# should typically use one of the higher level methods like +# callJMethod, callJStatic etc. instead of using this. +# +# isStatic - TRUE if the method to be called is static +# objId - String that refers to the object on which method is invoked +# Should be a jobj id for non-static methods and the classname +# for static methods +# methodName - name of method to be invoked +invokeJava <- function(isStatic, objId, methodName, ...) { + if (!exists(".sparkRCon", .sparkREnv)) { + stop("No connection to backend found. Please re-run sparkR.init") + } + + # If this isn't a removeJObject call + if (!isRemoveMethod(isStatic, objId, methodName)) { + objsToRemove <- ls(.toRemoveJobjs) + if (length(objsToRemove) > 0) { + sapply(objsToRemove, + function(e) { + removeJObject(e) + }) + rm(list = objsToRemove, envir = .toRemoveJobjs) + } + } + + + rc <- rawConnection(raw(0), "r+") + + writeBoolean(rc, isStatic) + writeString(rc, objId) + writeString(rc, methodName) + + args <- list(...) + writeInt(rc, length(args)) + writeArgs(rc, args) + + # Construct the whole request message to send it once, + # avoiding write-write-read pattern in case of Nagle's algorithm. + # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details. + bytesToSend <- rawConnectionValue(rc) + close(rc) + rc <- rawConnection(raw(0), "r+") + writeInt(rc, length(bytesToSend)) + writeBin(bytesToSend, rc) + requestMessage <- rawConnectionValue(rc) + close(rc) + + conn <- get(".sparkRCon", .sparkREnv) + writeBin(requestMessage, conn) + + # TODO: check the status code to output error information + returnStatus <- readInt(conn) + stopifnot(returnStatus == 0) + readObject(conn) +} diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R new file mode 100644 index 0000000000..583fa2e7fd --- /dev/null +++ b/R/pkg/R/broadcast.R @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# S4 class representing Broadcast variables + +# Hidden environment that holds values for broadcast variables +# This will not be serialized / shipped by default +.broadcastNames <- new.env() +.broadcastValues <- new.env() +.broadcastIdToName <- new.env() + +#' @title S4 class that represents a Broadcast variable +#' @description Broadcast variables can be created using the broadcast +#' function from a \code{SparkContext}. +#' @rdname broadcast-class +#' @seealso broadcast +#' +#' @param id Id of the backing Spark broadcast variable +#' @export +setClass("Broadcast", slots = list(id = "character")) + +#' @rdname broadcast-class +#' @param value Value of the broadcast variable +#' @param jBroadcastRef reference to the backing Java broadcast object +#' @param objName name of broadcasted object +#' @export +Broadcast <- function(id, value, jBroadcastRef, objName) { + .broadcastValues[[id]] <- value + .broadcastNames[[as.character(objName)]] <- jBroadcastRef + .broadcastIdToName[[id]] <- as.character(objName) + new("Broadcast", id = id) +} + +#' @description +#' \code{value} can be used to get the value of a broadcast variable inside +#' a distributed function. +#' +#' @param bcast The broadcast variable to get +#' @rdname broadcast +#' @aliases value,Broadcast-method +setMethod("value", + signature(bcast = "Broadcast"), + function(bcast) { + if (exists(bcast@id, envir = .broadcastValues)) { + get(bcast@id, envir = .broadcastValues) + } else { + NULL + } + }) + +#' Internal function to set values of a broadcast variable. +#' +#' This function is used internally by Spark to set the value of a broadcast +#' variable on workers. Not intended for use outside the package. +#' +#' @rdname broadcast-internal +#' @seealso broadcast, value + +#' @param bcastId The id of broadcast variable to set +#' @param value The value to be set +#' @export +setBroadcastValue <- function(bcastId, value) { + bcastIdStr <- as.character(bcastId) + .broadcastValues[[bcastIdStr]] <- value +} + +#' Helper function to clear the list of broadcast variables we know about +#' Should be called when the SparkR JVM backend is shutdown +clearBroadcastVariables <- function() { + bcasts <- ls(.broadcastNames) + rm(list = bcasts, envir = .broadcastNames) +} diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R new file mode 100644 index 0000000000..1281c41213 --- /dev/null +++ b/R/pkg/R/client.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Client code to connect to SparkRBackend + +# Creates a SparkR client connection object +# if one doesn't already exist +connectBackend <- function(hostname, port, timeout = 6000) { + if (exists(".sparkRcon", envir = .sparkREnv)) { + if (isOpen(.sparkREnv[[".sparkRCon"]])) { + cat("SparkRBackend client connection already exists\n") + return(get(".sparkRcon", envir = .sparkREnv)) + } + } + + con <- socketConnection(host = hostname, port = port, server = FALSE, + blocking = TRUE, open = "wb", timeout = timeout) + + assign(".sparkRCon", con, envir = .sparkREnv) + con +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { + if (.Platform$OS.type == "unix") { + sparkSubmitBinName = "spark-submit" + } else { + sparkSubmitBinName = "spark-submit.cmd" + } + + if (sparkHome != "") { + sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) + } else { + sparkSubmitBin <- sparkSubmitBinName + } + + if (jars != "") { + jars <- paste("--jars", jars) + } + + combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") + invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R new file mode 100644 index 0000000000..e196305186 --- /dev/null +++ b/R/pkg/R/column.R @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Column Class + +#' @include generics.R jobj.R +NULL + +setOldClass("jobj") + +#' @title S4 class that represents a DataFrame column +#' @description The column class supports unary, binary operations on DataFrame columns + +#' @rdname column +#' +#' @param jc reference to JVM DataFrame column +#' @export +setClass("Column", + slots = list(jc = "jobj")) + +setMethod("initialize", "Column", function(.Object, jc) { + .Object@jc <- jc + .Object +}) + +column <- function(jc) { + new("Column", jc) +} + +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' @rdname show +setMethod("show", "Column", + function(object) { + cat("Column", callJMethod(object@jc, "toString"), "\n") + }) + +operators <- list( + "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", + "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", + # we can not override `&&` and `||`, so use `&` and `|` instead + "&" = "and", "|" = "or" #, "!" = "unary_$bang" +) +column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") +functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", + "first", "last", "lower", "upper", "sumDistinct") + +createOperator <- function(op) { + setMethod(op, + signature(e1 = "Column"), + function(e1, e2) { + jc <- if (missing(e2)) { + if (op == "-") { + callJMethod(e1@jc, "unary_$minus") + } else { + callJMethod(e1@jc, operators[[op]]) + } + } else { + if (class(e2) == "Column") { + e2 <- e2@jc + } + callJMethod(e1@jc, operators[[op]], e2) + } + column(jc) + }) +} + +createColumnFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + column(callJMethod(x@jc, name)) + }) +} + +createColumnFunction2 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x, data) { + if (class(data) == "Column") { + data <- data@jc + } + jc <- callJMethod(x@jc, name, data) + column(jc) + }) +} + +createStaticFunction <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createMethods <- function() { + for (op in names(operators)) { + createOperator(op) + } + for (name in column_functions1) { + createColumnFunction1(name) + } + for (name in column_functions2) { + createColumnFunction2(name) + } + for (x in functions) { + createStaticFunction(x) + } +} + +createMethods() + +#' alias +#' +#' Set a new name for a column +setMethod("alias", + signature(object = "Column"), + function(object, data) { + if (is.character(data)) { + column(callJMethod(object@jc, "as", data)) + } else { + stop("data should be character") + } + }) + +#' An expression that returns a substring. +#' +#' @param start starting position +#' @param stop ending position +setMethod("substr", signature(x = "Column"), + function(x, start, stop) { + jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + column(jc) + }) + +#' Casts the column to a different data type. +#' @examples +#' \dontrun{ +#' cast(df$age, "string") +#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) +#' } +setMethod("cast", + signature(x = "Column"), + function(x, dataType) { + if (is.character(dataType)) { + column(callJMethod(x@jc, "cast", dataType)) + } else if (is.list(dataType)) { + json <- tojson(dataType) + jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) + column(callJMethod(x@jc, "cast", jdataType)) + } else { + stop("dataType should be character or list") + } + }) + +#' Approx Count Distinct +#' +#' Returns the approximate number of distinct items in a group. +#' +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' returns the number of distinct items in a group. +#' +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R new file mode 100644 index 0000000000..2fc0bb294b --- /dev/null +++ b/R/pkg/R/context.R @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# context.R: SparkContext driven functions + +getMinSplits <- function(sc, minSplits) { + if (is.null(minSplits)) { + defaultParallelism <- callJMethod(sc, "defaultParallelism") + minSplits <- min(defaultParallelism, 2) + } + as.integer(minSplits) +} + +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} +textFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits)) + # jrdd is of type JavaRDD[String] + RDD(jrdd, "string") +} + +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} +objectFile <- function(sc, path, minSplits = NULL) { + # Allow the user to have a more flexible definiton of the text file path + path <- suppressWarnings(normalizePath(path)) + #' Convert a string vector of paths to a string containing comma separated paths + path <- paste(path, collapse = ",") + + jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits)) + # Assume the RDD contains serialized R objects. + RDD(jrdd, "byte") +} + +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} +parallelize <- function(sc, coll, numSlices = 1) { + # TODO: bound/safeguard numSlices + # TODO: unit tests for if the split works for all primitives + # TODO: support matrix, data frame, etc + if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + if (is.data.frame(coll)) { + message(paste("context.R: A data frame is parallelized by columns.")) + } else { + if (is.matrix(coll)) { + message(paste("context.R: A matrix is parallelized by elements.")) + } else { + message(paste("context.R: parallelize() currently only supports lists and vectors.", + "Calling as.list() to coerce coll into a list.")) + } + } + coll <- as.list(coll) + } + + if (numSlices > length(coll)) + numSlices <- length(coll) + + sliceLen <- ceiling(length(coll) / numSlices) + slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + + # Serialize each slice: obtain a list of raws, or a list of lists (slices) of + # 2-tuples of raws + serializedSlices <- lapply(slices, serialize, connection = NULL) + + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", + "createRDDFromArray", sc, serializedSlices) + + RDD(jrdd, "byte") +} + +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' +#' @export +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} +includePackage <- function(sc, pkg) { + pkg <- as.character(substitute(pkg)) + if (exists(".packages", .sparkREnv)) { + packages <- .sparkREnv$.packages + } else { + packages <- list() + } + packages <- c(packages, pkg) + .sparkREnv$.packages <- packages +} + +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} +broadcast <- function(sc, object) { + objName <- as.character(substitute(object)) + serializedObj <- serialize(object, connection = NULL) + + jBroadcast <- callJMethod(sc, "broadcast", serializedObj) + id <- as.character(callJMethod(jBroadcast, "id")) + + Broadcast(id, object, jBroadcast, objName) +} + +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoints") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} +setCheckpointDir <- function(sc, dirName) { + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) +} diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R new file mode 100644 index 0000000000..257b435607 --- /dev/null +++ b/R/pkg/R/deserialize.R @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to deserialize objects from Java. + +# Type mapping from Java to R +# +# void -> NULL +# Int -> integer +# String -> character +# Boolean -> logical +# Double -> double +# Long -> double +# Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct +# +# Array[T] -> list() +# Object -> jobj + +readObject <- function(con) { + # Read type first + type <- readType(con) + readTypedObject(con, type) +} + +readTypedObject <- function(con, type) { + switch (type, + "i" = readInt(con), + "c" = readString(con), + "b" = readBoolean(con), + "d" = readDouble(con), + "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), + "l" = readList(con), + "n" = NULL, + "j" = getJobj(readString(con)), + stop(paste("Unsupported type for deserialization", type))) +} + +readString <- function(con) { + stringLen <- readInt(con) + string <- readBin(con, raw(), stringLen, endian = "big") + rawToChar(string) +} + +readInt <- function(con) { + readBin(con, integer(), n = 1, endian = "big") +} + +readDouble <- function(con) { + readBin(con, double(), n = 1, endian = "big") +} + +readBoolean <- function(con) { + as.logical(readInt(con)) +} + +readType <- function(con) { + rawToChar(readBin(con, "raw", n = 1L)) +} + +readDate <- function(con) { + as.Date(readString(con)) +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + +# We only support lists where all elements are of same type +readList <- function(con) { + type <- readType(con) + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + l[[i]] <- readTypedObject(con, type) + } + l + } else { + list() + } +} + +readRaw <- function(con) { + dataLen <- readInt(con) + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readRawLen <- function(con, dataLen) { + data <- readBin(con, raw(), as.integer(dataLen), endian = "big") +} + +readDeserialize <- function(con) { + # We have two cases that are possible - In one, the entire partition is + # encoded as a byte array, so we have only one value to read. If so just + # return firstData + dataLen <- readInt(con) + firstData <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + + # Else, read things into a list + dataLen <- readInt(con) + if (length(dataLen) > 0 && dataLen > 0) { + data <- list(firstData) + while (length(dataLen) > 0 && dataLen > 0) { + data[[length(data) + 1L]] <- unserialize( + readBin(con, raw(), as.integer(dataLen), endian = "big")) + dataLen <- readInt(con) + } + unlist(data, recursive = FALSE) + } else { + firstData + } +} + +readDeserializeRows <- function(inputCon) { + # readDeserializeRows will deserialize a DataOutputStream composed of + # a list of lists. Since the DOS is one continuous stream and + # the number of rows varies, we put the readRow function in a while loop + # that termintates when the next row is empty. + data <- list() + while(TRUE) { + row <- readRow(inputCon) + if (length(row) == 0) { + break + } + data[[length(data) + 1L]] <- row + } + data # this is a list of named lists now +} + +readRowList <- function(obj) { + # readRowList is meant for use inside an lapply. As a result, it is + # necessary to open a standalone connection for the row and consume + # the numCols bytes inside the read function in order to correctly + # deserialize the row. + rawObj <- rawConnection(obj, "r+") + on.exit(close(rawObj)) + readRow(rawObj) +} + +readRow <- function(inputCon) { + numCols <- readInt(inputCon) + if (length(numCols) > 0 && numCols > 0) { + lapply(1:numCols, function(x) { + obj <- readObject(inputCon) + if (is.null(obj)) { + NA + } else { + obj + } + }) # each row is a list now + } else { + list() + } +} + +# Take a single column as Array[Byte] and deserialize it into an atomic vector +readCol <- function(inputCon, numRows) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { + value <- readObject(inputCon) + # Replace NULL with NA so we can coerce to vectors + if (is.null(value)) NA else value + })) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R new file mode 100644 index 0000000000..5fb1ccaa84 --- /dev/null +++ b/R/pkg/R/generics.R @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############ RDD Actions and Transformations ############ + +#' @rdname aggregateRDD +#' @seealso reduce +#' @export +setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) + +#' @rdname cache-methods +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname coalesce +#' @seealso repartition +#' @export +setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) + +#' @rdname checkpoint-methods +#' @export +setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) + +#' @rdname collect-methods +#' @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) + +#' @rdname collect-methods +#' @export +setGeneric("collectPartition", + function(x, partitionId) { + standardGeneric("collectPartition") + }) + +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + +#' @rdname countByValue +#' @export +setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) + +#' @rdname distinct +#' @export +setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) + +#' @rdname filterRDD +#' @export +setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) + +#' @rdname first +#' @export +setGeneric("first", function(x) { standardGeneric("first") }) + +#' @rdname flatMap +#' @export +setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) + +#' @rdname fold +#' @seealso reduce +#' @export +setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) + +#' @rdname foreach +#' @export +setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) + +#' @rdname foreach +#' @export +setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) + +# The jrdd accessor function. +setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) + +#' @rdname glom +#' @export +setGeneric("glom", function(x) { standardGeneric("glom") }) + +#' @rdname keyBy +#' @export +setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("lapplyPartitionsWithIndex", + function(X, FUN) { + standardGeneric("lapplyPartitionsWithIndex") + }) + +#' @rdname lapply +#' @export +setGeneric("map", function(X, FUN) { standardGeneric("map") }) + +#' @rdname lapplyPartition +#' @export +setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) + +#' @rdname lapplyPartitionsWithIndex +#' @export +setGeneric("mapPartitionsWithIndex", + function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) + +#' @rdname maximum +#' @export +setGeneric("maximum", function(x) { standardGeneric("maximum") }) + +#' @rdname minimum +#' @export +setGeneric("minimum", function(x) { standardGeneric("minimum") }) + +#' @rdname sumRDD +#' @export +setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) + +#' @rdname name +#' @export +setGeneric("name", function(x) { standardGeneric("name") }) + +#' @rdname numPartitions +#' @export +setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) + +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +#' @rdname pipeRDD +#' @export +setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) + +#' @rdname reduce +#' @export +setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) + +#' @rdname repartition +#' @seealso coalesce +#' @export +setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) + +#' @rdname sampleRDD +#' @export +setGeneric("sampleRDD", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleRDD") + }) + +#' @rdname saveAsObjectFile +#' @seealso objectFile +#' @export +setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) + +#' @rdname saveAsTextFile +#' @export +setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) + +#' @rdname setName +#' @export +setGeneric("setName", function(x, name) { standardGeneric("setName") }) + +#' @rdname sortBy +#' @export +setGeneric("sortBy", + function(x, func, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortBy") + }) + +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + +#' @rdname takeOrdered +#' @export +setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) + +#' @rdname takeSample +#' @export +setGeneric("takeSample", + function(x, withReplacement, num, seed) { + standardGeneric("takeSample") + }) + +#' @rdname top +#' @export +setGeneric("top", function(x, num) { standardGeneric("top") }) + +#' @rdname unionRDD +#' @export +setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) + +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + +#' @rdname zipRDD +#' @export +setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) + +#' @rdname zipWithIndex +#' @seealso zipWithUniqueId +#' @export +setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) + +#' @rdname zipWithUniqueId +#' @seealso zipWithIndex +#' @export +setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) + + +############ Binary Functions ############# + +#' @rdname countByKey +#' @export +setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) + +#' @rdname flatMapValues +#' @export +setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) + +#' @rdname keys +#' @export +setGeneric("keys", function(x) { standardGeneric("keys") }) + +#' @rdname lookup +#' @export +setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) + +#' @rdname mapValues +#' @export +setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) + +#' @rdname values +#' @export +setGeneric("values", function(x) { standardGeneric("values") }) + + + +############ Shuffle Functions ############ + +#' @rdname aggregateByKey +#' @seealso foldByKey, combineByKey +#' @export +setGeneric("aggregateByKey", + function(x, zeroValue, seqOp, combOp, numPartitions) { + standardGeneric("aggregateByKey") + }) + +#' @rdname cogroup +#' @export +setGeneric("cogroup", + function(..., numPartitions) { + standardGeneric("cogroup") + }, + signature = "...") + +#' @rdname combineByKey +#' @seealso groupByKey, reduceByKey +#' @export +setGeneric("combineByKey", + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + standardGeneric("combineByKey") + }) + +#' @rdname foldByKey +#' @seealso aggregateByKey, combineByKey +#' @export +setGeneric("foldByKey", + function(x, zeroValue, func, numPartitions) { + standardGeneric("foldByKey") + }) + +#' @rdname join-methods +#' @export +setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) + +#' @rdname groupByKey +#' @seealso reduceByKey +#' @export +setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) + +#' @rdname join-methods +#' @export +setGeneric("join", function(x, y, ...) { standardGeneric("join") }) + +#' @rdname join-methods +#' @export +setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) + +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) + +#' @rdname reduceByKey +#' @seealso groupByKey +#' @export +setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) + +#' @rdname reduceByKeyLocally +#' @seealso reduceByKey +#' @export +setGeneric("reduceByKeyLocally", + function(x, combineFunc) { + standardGeneric("reduceByKeyLocally") + }) + +#' @rdname join-methods +#' @export +setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) + +#' @rdname sortByKey +#' @export +setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") +}) + + +################### Broadcast Variable Methods ################# + +#' @rdname broadcast +#' @export +setGeneric("value", function(bcast) { standardGeneric("value") }) + + + +#################### DataFrame Methods ######################## + +#' @rdname schema +#' @export +setGeneric("columns", function(x) {standardGeneric("columns") }) + +#' @rdname schema +#' @export +setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) + +#' @rdname explain +#' @export +setGeneric("explain", function(x, ...) { standardGeneric("explain") }) + +#' @rdname filter +#' @export +setGeneric("filter", function(x, condition) { standardGeneric("filter") }) + +#' @rdname DataFrame +#' @export +setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) + +#' @rdname insertInto +#' @export +setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) + +#' @rdname intersect +#' @export +setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) + +#' @rdname isLocal +#' @export +setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) + +#' @rdname limit +#' @export +setGeneric("limit", function(x, num) {standardGeneric("limit") }) + +#' @rdname sortDF +#' @export +setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) + +#' @rdname schema +#' @export +setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) + +#' @rdname registerTempTable +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + +#' @rdname sampleDF +#' @export +setGeneric("sampleDF", + function(x, withReplacement, fraction, seed) { + standardGeneric("sampleDF") + }) + +#' @rdname saveAsParquetFile +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname saveAsTable +#' @export +setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { + standardGeneric("saveAsTable") +}) + +#' @rdname saveAsTable +#' @export +setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) + +#' @rdname schema +#' @export +setGeneric("schema", function(x) { standardGeneric("schema") }) + +#' @rdname select +#' @export +setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) + +#' @rdname select +#' @export +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname showDF +#' @export +setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) + +#' @rdname sortDF +#' @export +setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) + +#' @rdname subtract +#' @export +setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) + +#' @rdname tojson +#' @export +setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) + +#' @rdname DataFrame +#' @export +setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) + +#' @rdname unionAll +#' @export +setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) + +#' @rdname filter +#' @export +setGeneric("where", function(x, condition) { standardGeneric("where") }) + +#' @rdname withColumn +#' @export +setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) + +#' @rdname withColumnRenamed +#' @export +setGeneric("withColumnRenamed", function(x, existingCol, newCol) { + standardGeneric("withColumnRenamed") }) + + +###################### Column Methods ########################## + +#' @rdname column +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname column +#' @export +setGeneric("asc", function(x) { standardGeneric("asc") }) + +#' @rdname column +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname column +#' @export +setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) + +#' @rdname column +#' @export +setGeneric("contains", function(x, ...) { standardGeneric("contains") }) +#' @rdname column +#' @export +setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) + +#' @rdname column +#' @export +setGeneric("desc", function(x) { standardGeneric("desc") }) + +#' @rdname column +#' @export +setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) + +#' @rdname column +#' @export +setGeneric("getField", function(x, ...) { standardGeneric("getField") }) + +#' @rdname column +#' @export +setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) + +#' @rdname column +#' @export +setGeneric("isNull", function(x) { standardGeneric("isNull") }) + +#' @rdname column +#' @export +setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) + +#' @rdname column +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname column +#' @export +setGeneric("like", function(x, ...) { standardGeneric("like") }) + +#' @rdname column +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname column +#' @export +setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) + +#' @rdname column +#' @export +setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) + +#' @rdname column +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname column +#' @export +setGeneric("upper", function(x) { standardGeneric("upper") }) + diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R new file mode 100644 index 0000000000..09fc0a7abe --- /dev/null +++ b/R/pkg/R/group.R @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# group.R - GroupedData class and methods implemented in S4 OO classes + +setOldClass("jobj") + +#' @title S4 class that represents a GroupedData +#' @description GroupedDatas can be created using groupBy() on a DataFrame +#' @rdname GroupedData +#' @seealso groupBy +#' +#' @param sgd A Java object reference to the backing Scala GroupedData +#' @export +setClass("GroupedData", + slots = list(sgd = "jobj")) + +setMethod("initialize", "GroupedData", function(.Object, sgd) { + .Object@sgd <- sgd + .Object +}) + +#' @rdname DataFrame +groupedData <- function(sgd) { + new("GroupedData", sgd) +} + + +#' @rdname show +setMethod("show", "GroupedData", + function(object) { + cat("GroupedData\n") + }) + +#' Count +#' +#' Count the number of rows for each group. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @export +#' @examples +#' \dontrun{ +#' count(groupBy(df, "name")) +#' } +setMethod("count", + signature(x = "GroupedData"), + function(x) { + dataFrame(callJMethod(x@sgd, "count")) + }) + +#' Agg +#' +#' Aggregates on the entire DataFrame without groups. +#' The resulting DataFrame will also contain the grouping columns. +#' +#' df2 <- agg(df, = ) +#' df2 <- agg(df, newColName = aggFunction(column)) +#' +#' @param x a GroupedData +#' @return a DataFrame +#' @rdname agg +#' @examples +#' \dontrun{ +#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' +#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' } +setGeneric("agg", function (x, ...) { standardGeneric("agg") }) + +setMethod("agg", + signature(x = "GroupedData"), + function(x, ...) { + cols = list(...) + stopifnot(length(cols) > 0) + if (is.character(cols[[1]])) { + cols <- varargsToEnv(...) + sdf <- callJMethod(x@sgd, "agg", cols) + } else if (class(cols[[1]]) == "Column") { + ns <- names(cols) + if (!is.null(ns)) { + for (n in ns) { + if (n != "") { + cols[[n]] = alias(cols[[n]], n) + } + } + } + jcols <- lapply(cols, function(c) { c@jc }) + # the GroupedData.agg(col, cols*) API does not contain grouping Column + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping", + x@sgd, listToSeq(jcols)) + } else { + stop("agg can only support Column or character") + } + dataFrame(sdf) + }) + + +# sum/mean/avg/min/max +methods <- c("sum", "mean", "avg", "min", "max") + +createMethod <- function(name) { + setMethod(name, + signature(x = "GroupedData"), + function(x, ...) { + sdf <- callJMethod(x@sgd, name, toSeq(...)) + dataFrame(sdf) + }) +} + +createMethods <- function() { + for (name in methods) { + createMethod(name) + } +} + +createMethods() + diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R new file mode 100644 index 0000000000..4180f146b7 --- /dev/null +++ b/R/pkg/R/jobj.R @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# References to objects that exist on the JVM backend +# are maintained using the jobj. + +# Maintain a reference count of Java object references +# This allows us to GC the java object when it is safe +.validJobjs <- new.env(parent = emptyenv()) + +# List of object ids to be removed +.toRemoveJobjs <- new.env(parent = emptyenv()) + +# Check if jobj was created with the current SparkContext +isValidJobj <- function(jobj) { + if (exists(".scStartTime", envir = .sparkREnv)) { + jobj$appId == get(".scStartTime", envir = .sparkREnv) + } else { + FALSE + } +} + +getJobj <- function(objId) { + newObj <- jobj(objId) + if (exists(objId, .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] + 1 + } else { + .validJobjs[[objId]] <- 1 + } + newObj +} + +# Handler for a java object that exists on the backend. +jobj <- function(objId) { + if (!is.character(objId)) { + stop("object id must be a character") + } + # NOTE: We need a new env for a jobj as we can only register + # finalizers for environments or external references pointers. + obj <- structure(new.env(parent = emptyenv()), class = "jobj") + obj$id <- objId + obj$appId <- get(".scStartTime", envir = .sparkREnv) + + # Register a finalizer to remove the Java object when this reference + # is garbage collected in R + reg.finalizer(obj, cleanup.jobj) + obj +} + +#' Print a JVM object reference. +#' +#' This function prints the type and id for an object stored +#' in the SparkR JVM backend. +#' +#' @param x The JVM object reference +#' @param ... further arguments passed to or from other methods +print.jobj <- function(x, ...) { + cls <- callJMethod(x, "getClass") + name <- callJMethod(cls, "getName") + cat("Java ref type", name, "id", x$id, "\n", sep = " ") +} + +cleanup.jobj <- function(jobj) { + if (isValidJobj(jobj)) { + objId <- jobj$id + # If we don't know anything about this jobj, ignore it + if (exists(objId, envir = .validJobjs)) { + .validJobjs[[objId]] <- .validJobjs[[objId]] - 1 + + if (.validJobjs[[objId]] == 0) { + rm(list = objId, envir = .validJobjs) + # NOTE: We cannot call removeJObject here as the finalizer may be run + # in the middle of another RPC. Thus we queue up this object Id to be removed + # and then run all the removeJObject when the next RPC is called. + .toRemoveJobjs[[objId]] <- 1 + } + } + } +} + +clearJobjs <- function() { + valid <- ls(.validJobjs) + rm(list = valid, envir = .validJobjs) + + removeList <- ls(.toRemoveJobjs) + rm(list = removeList, envir = .toRemoveJobjs) +} diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R new file mode 100644 index 0000000000..c2396c32a7 --- /dev/null +++ b/R/pkg/R/pairRDD.R @@ -0,0 +1,789 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Operations supported on RDDs contains pairs (i.e key, value) + +############ Actions and Transformations ############ + +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +setMethod("lookup", + signature(x = "RDD", key = "ANY"), + function(x, key) { + partitionFunc <- function(part) { + filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))] + lapply(filtered, function(i) { i[[2]] }) + } + valsRDD <- lapplyPartition(x, partitionFunc) + collect(valsRDD) + }) + +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +setMethod("countByKey", + signature(x = "RDD"), + function(x) { + keys <- lapply(x, function(item) { item[[1]] }) + countByValue(keys) + }) + +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +setMethod("keys", + signature(x = "RDD"), + function(x) { + func <- function(k) { + k[[1]] + } + lapply(x, func) + }) + +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +setMethod("values", + signature(x = "RDD"), + function(x) { + func <- function(v) { + v[[2]] + } + lapply(x, func) + }) + +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +setMethod("mapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + func <- function(x) { + list(x[[1]], FUN(x[[2]])) + } + lapply(X, func) + }) + +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +setMethod("flatMapValues", + signature(X = "RDD", FUN = "function"), + function(X, FUN) { + flatMapFunc <- function(x) { + lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) }) + } + flatMap(X, flatMapFunc) + }) + +############ Shuffle Functions ############ + +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +setMethod("partitionBy", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions, partitionFunc = hashCode) { + + #if (missing(partitionFunc)) { + # partitionFunc <- hashCode + #} + + partitionFunc <- cleanClosure(partitionFunc) + serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) + + packageNamesArr <- serialize(.sparkREnv$.packages, + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), function(name) { + get(name, .broadcastNames) }) + jrdd <- getJRDD(x) + + # We create a PairwiseRRDD that extends RDD[(Array[Byte], + # Array[Byte])], where the key is the hashed split, the value is + # the content (key-val pairs). + pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", + callJMethod(jrdd, "rdd"), + as.integer(numPartitions), + serializedHashFuncBytes, + getSerializedMode(x), + packageNamesArr, + as.character(.sparkREnv$libname), + broadcastArr, + callJMethod(jrdd, "classTag")) + + # Create a corresponding partitioner. + rPartitioner <- newJObject("org.apache.spark.HashPartitioner", + as.integer(numPartitions)) + + # Call partitionBy on the obtained PairwiseRDD. + javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") + javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner) + + # Call .values() on the result to get back the final result, the + # shuffled acutal content key-val pairs. + r <- callJMethod(javaPairRDD, "values") + + RDD(r, serializedMode = "byte") + }) + +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +setMethod("groupByKey", + signature(x = "RDD", numPartitions = "integer"), + function(x, numPartitions) { + shuffled <- partitionBy(x, numPartitions) + groupVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + appendList <- function(acc, i) { + addItemToAccumulator(acc, i) + acc + } + makeList <- function(i) { + acc <- initAccumulator() + addItemToAccumulator(acc, i) + acc + } + # Each item in the partition is list of (K, V) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, + appendList, makeList) + }) + # extract out data field + vals <- eapply(vals, + function(i) { + length(i$data) <- i$counter + i$data + }) + # Every key in the environment contains a list + # Convert that to list(K, Seq[V]) + convertEnvsToList(keys, vals) + } + lapplyPartition(shuffled, groupVals) + }) + +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +setMethod("reduceByKey", + signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + function(x, combineFunc, numPartitions) { + reduceVals <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + convertEnvsToList(keys, vals) + } + locallyReduced <- lapplyPartition(x, reduceVals) + shuffled <- partitionBy(locallyReduced, numPartitions) + lapplyPartition(shuffled, reduceVals) + }) + +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +setMethod("reduceByKeyLocally", + signature(x = "RDD", combineFunc = "ANY"), + function(x, combineFunc) { + reducePart <- function(part) { + vals <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(hashCode(item[[1]])) + updateOrCreatePair(item, keys, vals, pred, combineFunc, identity) + }) + list(list(keys, vals)) # return hash to avoid re-compute in merge + } + mergeParts <- function(accum, x) { + pred <- function(item) { + exists(item$hash, accum[[1]]) + } + lapply(ls(x[[1]]), + function(name) { + item <- list(x[[1]][[name]], x[[2]][[name]]) + item$hash <- name + updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity) + }) + accum + } + reduced <- mapPartitions(x, reducePart) + merged <- reduce(reduced, mergeParts) + convertEnvsToList(merged[[1]], merged[[2]]) + }) + +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). + +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("combineByKey", + signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", + mergeCombiners = "ANY", numPartitions = "integer"), + function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { + combineLocally <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) + }) + convertEnvsToList(keys, combiners) + } + locallyCombined <- lapplyPartition(x, combineLocally) + shuffled <- partitionBy(locallyCombined, numPartitions) + mergeAfterShuffle <- function(part) { + combiners <- new.env() + keys <- new.env() + pred <- function(item) exists(item$hash, keys) + lapply(part, + function(item) { + item$hash <- as.character(item[[1]]) + updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) + }) + convertEnvsToList(keys, combiners) + } + lapplyPartition(shuffled, mergeAfterShuffle) + }) + +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +setMethod("aggregateByKey", + signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", + combOp = "ANY", numPartitions = "integer"), + function(x, zeroValue, seqOp, combOp, numPartitions) { + createCombiner <- function(v) { + do.call(seqOp, list(zeroValue, v)) + } + + combineByKey(x, createCombiner, seqOp, combOp, numPartitions) + }) + +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +setMethod("foldByKey", + signature(x = "RDD", zeroValue = "ANY", + func = "ANY", numPartitions = "integer"), + function(x, zeroValue, func, numPartitions) { + aggregateByKey(x, zeroValue, func, func, numPartitions) + }) + +############ Binary Functions ############# + +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +setMethod("join", + signature(x = "RDD", y = "RDD"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + doJoin) + }) + +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +setMethod("leftOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(FALSE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +setMethod("rightOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, FALSE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +setMethod("fullOuterJoin", + signature(x = "RDD", y = "RDD", numPartitions = "integer"), + function(x, y, numPartitions) { + xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) + yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) + + doJoin <- function(v) { + joinTaggedList(v, list(TRUE, TRUE)) + } + + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) + }) + +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +setMethod("cogroup", + "RDD", + function(..., numPartitions) { + rdds <- list(...) + rddsLen <- length(rdds) + for (i in 1:rddsLen) { + rdds[[i]] <- lapply(rdds[[i]], + function(x) { list(x[[1]], list(i, x[[2]])) }) + # TODO(hao): As issue [SparkR-142] mentions, the right value of i + # will not be captured into UDF if getJRDD is not invoked. + # It should be resolved together with that issue. + getJRDD(rdds[[i]]) # Capture the closure. + } + union.rdd <- Reduce(unionRDD, rdds) + group.func <- function(vlist) { + res <- list() + length(res) <- rddsLen + for (x in vlist) { + i <- x[[1]] + acc <- res[[i]] + # Create an accumulator. + if (is.null(acc)) { + acc <- initAccumulator() + } + addItemToAccumulator(acc, x[[2]]) + res[[i]] <- acc + } + lapply(res, function(acc) { + if (is.null(acc)) { + list() + } else { + acc$data + } + }) + } + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + group.func) + }) + +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +setMethod("sortByKey", + signature(x = "RDD"), + function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { + rangeBounds <- list() + + if (numPartitions > 1) { + rddSize <- count(x) + # constant from Spark's RangePartitioner + maxSampleSize <- numPartitions * 20 + fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) + + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + + # Note: the built-in R sort() function only works on atomic vectors + samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) + + if (length(samples) > 0) { + rangeBounds <- lapply(seq_len(numPartitions - 1), + function(i) { + j <- ceiling(length(samples) * i / numPartitions) + samples[j] + }) + } + } + + rangePartitionFunc <- function(key) { + partition <- 0 + + # TODO: Use binary search instead of linear search, similar with Spark + while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { + partition <- partition + 1 + } + + if (ascending) { + partition + } else { + numPartitions - partition - 1 + } + } + + partitionFunc <- function(part) { + sortKeyValueList(part, decreasing = !ascending) + } + + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + lapplyPartition(newRDD, partitionFunc) + }) + diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R new file mode 100644 index 0000000000..8a9c0c652c --- /dev/null +++ b/R/pkg/R/serialize.R @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility functions to serialize R objects so they can be read in Java. + +# Type mapping from R to Java +# +# NULL -> Void +# integer -> Int +# character -> String +# logical -> Boolean +# double, numeric -> Double +# raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time +# +# list[T] -> Array[T], where T is one of above mentioned types +# environment -> Map[String, T], where T is a native type +# jobj -> Object, where jobj is an object created in the backend + +writeObject <- function(con, object, writeType = TRUE) { + # NOTE: In R vectors have same type as objects. So we don't support + # passing in vectors as arrays and instead require arrays to be passed + # as lists. + type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + if (writeType) { + writeType(con, type) + } + switch(type, + NULL = writeVoid(con), + integer = writeInt(con, object), + character = writeString(con, object), + logical = writeBoolean(con, object), + double = writeDouble(con, object), + numeric = writeDouble(con, object), + raw = writeRaw(con, object), + list = writeList(con, object), + jobj = writeJobj(con, object), + environment = writeEnv(con, object), + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL +} + +writeJobj <- function(con, value) { + if (!isValidJobj(value)) { + stop("invalid jobj ", value$id) + } + writeString(con, value$id) +} + +writeString <- function(con, value) { + writeInt(con, as.integer(nchar(value) + 1)) + writeBin(value, con, endian = "big") +} + +writeInt <- function(con, value) { + writeBin(as.integer(value), con, endian = "big") +} + +writeDouble <- function(con, value) { + writeBin(value, con, endian = "big") +} + +writeBoolean <- function(con, value) { + # TRUE becomes 1, FALSE becomes 0 + writeInt(con, as.integer(value)) +} + +writeRawSerialize <- function(outputCon, batch) { + outputSer <- serialize(batch, ascii = FALSE, connection = NULL) + writeRaw(outputCon, outputSer) +} + +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + writeRow(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRow <- function(con, row) { + numCols <- length(row) + writeInt(con, numCols) + for (i in 1:numCols) { + writeObject(con, row[[i]]) + } +} + +writeRaw <- function(con, batch) { + writeInt(con, length(batch)) + writeBin(batch, con, endian = "big") +} + +writeType <- function(con, class) { + type <- switch(class, + NULL = "n", + integer = "i", + character = "c", + logical = "b", + double = "d", + numeric = "d", + raw = "r", + list = "l", + jobj = "j", + environment = "e", + Date = "D", + POSIXlt = 't', + POSIXct = 't', + stop(paste("Unsupported type for serialization", class))) + writeBin(charToRaw(type), con) +} + +# Used to pass arrays where all the elements are of the same type +writeList <- function(con, arr) { + # All elements should be of same type + elemType <- unique(sapply(arr, function(elem) { class(elem) })) + stopifnot(length(elemType) <= 1) + + # TODO: Empty lists are given type "character" right now. + # This may not work if the Java side expects array of any other type. + if (length(elemType) == 0) { + elemType <- class("somestring") + } + + writeType(con, elemType) + writeInt(con, length(arr)) + + if (length(arr) > 0) { + for (a in arr) { + writeObject(con, a, FALSE) + } + } +} + +# Used to pass in hash maps required on Java side. +writeEnv <- function(con, env) { + len <- length(env) + + writeInt(con, len) + if (len > 0) { + writeList(con, as.list(ls(env))) + vals <- lapply(ls(env), function(x) { env[[x]] }) + writeList(con, as.list(vals)) + } +} + +writeDate <- function(con, date) { + writeString(con, as.character(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + +# Used to serialize in a list of objects where each +# object can be of a different type. Serialization format is +# for each object +writeArgs <- function(con, args) { + if (length(args) > 0) { + for (a in args) { + writeObject(con, a) + } + } +} + +writeStrings <- function(con, stringList) { + writeLines(unlist(stringList), con) +} diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R new file mode 100644 index 0000000000..bc82df01f0 --- /dev/null +++ b/R/pkg/R/sparkR.R @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.sparkREnv <- new.env() + +sparkR.onLoad <- function(libname, pkgname) { + .sparkREnv$libname <- libname +} + +# Utility function that returns TRUE if we have an active connection to the +# backend and FALSE otherwise +connExists <- function(env) { + tryCatch({ + exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) + }, error = function(err) { + return(FALSE) + }) +} + +#' Stop the Spark context. +#' +#' Also terminates the backend this R session is connected to +sparkR.stop <- function() { + env <- .sparkREnv + if (exists(".sparkRCon", envir = env)) { + # cat("Stopping SparkR\n") + if (exists(".sparkRjsc", envir = env)) { + sc <- get(".sparkRjsc", envir = env) + callJMethod(sc, "stop") + rm(".sparkRjsc", envir = env) + } + + if (exists(".backendLaunched", envir = env)) { + callJStatic("SparkRHandler", "stopBackend") + } + + # Also close the connection and remove it from our env + conn <- get(".sparkRCon", envir = env) + close(conn) + + rm(".sparkRCon", envir = env) + rm(".scStartTime", envir = env) + } + + if (exists(".monitorConn", envir = env)) { + conn <- get(".monitorConn", envir = env) + close(conn) + rm(".monitorConn", envir = env) + } + + # Clear all broadcast variables we have + # as the jobj will not be valid if we restart the JVM + clearBroadcastVariables() + + # Clear jobj maps + clearJobjs() +} + +#' Initialize a new Spark Context. +#' +#' This function initializes a new SparkContext. +#' +#' @param master The Spark master URL. +#' @param appName Application name to register with cluster manager +#' @param sparkHome Spark Home directory +#' @param sparkEnvir Named list of environment variables to set on worker nodes. +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. +#' @param sparkJars Character string vector of jar files to pass to the worker nodes. +#' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") +#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g")) +#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="1g"), +#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), +#' c("jarfile1.jar","jarfile2.jar")) +#'} + +sparkR.init <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvir = list(), + sparkExecutorEnv = list(), + sparkJars = "", + sparkRLibDir = "") { + + if (exists(".sparkRjsc", envir = .sparkREnv)) { + cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + return(get(".sparkRjsc", envir = .sparkREnv)) + } + + sparkMem <- Sys.getenv("SPARK_MEM", "512m") + jars <- suppressWarnings(normalizePath(as.character(sparkJars))) + + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + collapseChar <- ":" + uriSep <- "//" + } else { + collapseChar <- ";" + uriSep <- "////" + } + + existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + if (existingPort != "") { + backendPort <- existingPort + } else { + path <- tempfile(pattern = "backend_port") + launchBackend( + args = path, + sparkHome = sparkHome, + jars = jars, + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + # wait atmost 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 10 seconds") + } + f <- file(path, open='rb') + backendPort <- readInt(f) + monitorPort <- readInt(f) + close(f) + file.remove(path) + if (length(backendPort) == 0 || backendPort == 0 || + length(monitorPort) == 0 || monitorPort == 0) { + stop("JVM failed to launch") + } + assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".backendLaunched", 1, envir = .sparkREnv) + } + + .sparkREnv$backendPort <- backendPort + tryCatch({ + connectBackend("localhost", backendPort) + }, error = function(err) { + stop("Failed to connect JVM\n") + }) + + if (nchar(sparkHome) != 0) { + sparkHome <- normalizePath(sparkHome) + } + + if (nchar(sparkRLibDir) != 0) { + .sparkREnv$libname <- sparkRLibDir + } + + sparkEnvirMap <- new.env() + for (varname in names(sparkEnvir)) { + sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] + } + + sparkExecutorEnvMap <- new.env() + if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + } + for (varname in names(sparkExecutorEnv)) { + sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] + } + + nonEmptyJars <- Filter(function(x) { x != "" }, jars) + localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + + # Set the start time to identify jobjs + # Seconds resolution is good enough for this purpose, so use ints + assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv) + + assign( + ".sparkRjsc", + callJStatic( + "org.apache.spark.api.r.RRDD", + "createSparkContext", + master, + appName, + as.character(sparkHome), + as.list(localJarPaths), + sparkEnvirMap, + sparkExecutorEnvMap), + envir = .sparkREnv + ) + + sc <- get(".sparkRjsc", envir = .sparkREnv) + + # Register a finalizer to sleep 1 seconds on R exit to make RStudio happy + reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE) + + sc +} + +#' Initialize a new SQLContext. +#' +#' This function creates a SparkContext from an existing JavaSparkContext and +#' then uses it to initialize a new SQLContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#'} + +sparkRSQL.init <- function(jsc) { + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + return(get(".sparkRSQLsc", envir = .sparkREnv)) + } + + sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + jsc) + assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) + sqlCtx +} + +#' Initialize a new HiveContext. +#' +#' This function creates a HiveContext from an existing JavaSparkContext +#' +#' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRHive.init(sc) +#'} + +sparkRHive.init <- function(jsc) { + if (exists(".sparkRHivesc", envir = .sparkREnv)) { + return(get(".sparkRHivesc", envir = .sparkREnv)) + } + + ssc <- callJMethod(jsc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.HiveContext", ssc) + }, error = function(err) { + stop("Spark SQL is not built with Hive support") + }) + + assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) + hiveCtx +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R new file mode 100644 index 0000000000..c337fb0751 --- /dev/null +++ b/R/pkg/R/utils.R @@ -0,0 +1,467 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utilities and Helpers + +# Given a JList, returns an R list containing the same elements, the number +# of which is optionally upper bounded by `logicalUpperBound` (by default, +# return all elements). Takes care of deserializations and type conversions. +convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, + serializedMode = "byte") { + arrSize <- callJMethod(jList, "size") + + # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()): + # each partition is not dense-packed into one Array[Byte], and `arrSize` + # here corresponds to number of logical elements. Thus we can prune here. + if (serializedMode == "string" && !is.null(logicalUpperBound)) { + arrSize <- min(arrSize, logicalUpperBound) + } + + results <- if (arrSize > 0) { + lapply(0:(arrSize - 1), + function(index) { + obj <- callJMethod(jList, "get", as.integer(index)) + + # Assume it is either an R object or a Java obj ref. + if (inherits(obj, "jobj")) { + if (isInstanceOf(obj, "scala.Tuple2")) { + # JavaPairRDD[Array[Byte], Array[Byte]]. + + keyBytes = callJMethod(obj, "_1") + valBytes = callJMethod(obj, "_2") + res <- list(unserialize(keyBytes), + unserialize(valBytes)) + } else { + stop(paste("utils.R: convertJListToRList only supports", + "RDD[Array[Byte]] and", + "JavaPairRDD[Array[Byte], Array[Byte]] for now")) + } + } else { + if (inherits(obj, "raw")) { + if (serializedMode == "byte") { + # RDD[Array[Byte]]. `obj` is a whole partition. + res <- unserialize(obj) + # For serialized datasets, `obj` (and `rRaw`) here corresponds to + # one whole partition dense-packed together. We deserialize the + # whole partition first, then cap the number of elements to be returned. + } else if (serializedMode == "row") { + res <- readRowList(obj) + # For DataFrames that have been converted to RRDDs, we call readRowList + # which will read in each row of the RRDD as a list and deserialize + # each element. + flatten <<- FALSE + # Use global assignment to change the flatten flag. This means + # we don't have to worry about the default argument in other functions + # e.g. collect + } + # TODO: is it possible to distinguish element boundary so that we can + # unserialize only what we need? + if (!is.null(logicalUpperBound)) { + res <- head(res, n = logicalUpperBound) + } + } else { + # obj is of a primitive Java type, is simplified to R's + # corresponding type. + res <- list(obj) + } + } + res + }) + } else { + list() + } + + if (flatten) { + as.list(unlist(results, recursive = FALSE)) + } else { + as.list(results) + } +} + +# Returns TRUE if `name` refers to an RDD in the given environment `env` +isRDD <- function(name, env) { + obj <- get(name, envir = env) + inherits(obj, "RDD") +} + +#' Compute the hashCode of an object +#' +#' Java-style function to compute the hashCode for the given object. Returns +#' an integer value. +#' +#' @details +#' This only works for integer, numeric and character types right now. +#' +#' @param key the object to be hashed +#' @return the hash code as an integer +#' @export +#' @examples +#' hashCode(1L) # 1 +#' hashCode(1.0) # 1072693248 +#' hashCode("1") # 49 +hashCode <- function(key) { + if (class(key) == "integer") { + as.integer(key[[1]]) + } else if (class(key) == "numeric") { + # Convert the double to long and then calculate the hash code + rawVec <- writeBin(key[[1]], con = raw()) + intBits <- packBits(rawToBits(rawVec), "integer") + as.integer(bitwXor(intBits[2], intBits[1])) + } else if (class(key) == "character") { + .Call("stringHashCode", key) + } else { + warning(paste("Could not hash object, returning 0", sep = "")) + as.integer(0) + } +} + +# Create a new RDD with serializedMode == "byte". +# Return itself if already in "byte" format. +serializeToBytes <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "byte") { + ser.rdd <- lapply(rdd, function(x) { x }) + return(ser.rdd) + } else { + return(rdd) + } +} + +# Create a new RDD with serializedMode == "string". +# Return itself if already in "string" format. +serializeToString <- function(rdd) { + if (!inherits(rdd, "RDD")) { + stop("Argument 'rdd' is not an RDD type.") + } + if (getSerializedMode(rdd) != "string") { + ser.rdd <- lapply(rdd, function(x) { toString(x) }) + # force it to create jrdd using "string" + getJRDD(ser.rdd, serializedMode = "string") + return(ser.rdd) + } else { + return(rdd) + } +} + +# Fast append to list by using an accumulator. +# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r +# +# The accumulator should has three fields size, counter and data. +# This function amortizes the allocation cost by doubling +# the size of the list every time it fills up. +addItemToAccumulator <- function(acc, item) { + if(acc$counter == acc$size) { + acc$size <- acc$size * 2 + length(acc$data) <- acc$size + } + acc$counter <- acc$counter + 1 + acc$data[[acc$counter]] <- item +} + +initAccumulator <- function() { + acc <- new.env() + acc$counter <- 0 + acc$data <- list(NULL) + acc$size <- 1 + acc +} + +# Utility function to sort a list of key value pairs +# Used in unit tests +sortKeyValueList <- function(kv_list, decreasing = FALSE) { + keys <- sapply(kv_list, function(x) x[[1]]) + kv_list[order(keys, decreasing = decreasing)] +} + +# Utility function to generate compact R lists from grouped rdd +# Used in Join-family functions +# param: +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +genCompactLists <- function(tagged_list, cnull) { + len <- length(tagged_list) + lists <- list(vector("list", len), vector("list", len)) + index <- list(1, 1) + + for (x in tagged_list) { + tag <- x[[1]] + idx <- index[[tag]] + lists[[tag]][[idx]] <- x[[2]] + index[[tag]] <- idx + 1 + } + + len <- lapply(index, function(x) x - 1) + for (i in (1:2)) { + if (cnull[[i]] && len[[i]] == 0) { + lists[[i]] <- list(NULL) + } else { + length(lists[[i]]) <- len[[i]] + } + } + + lists +} + +# Utility function to merge compact R lists +# Used in Join-family functions +# param: +# left/right Two compact lists ready for Cartesian product +mergeCompactLists <- function(left, right) { + result <- list() + length(result) <- length(left) * length(right) + index <- 1 + for (i in left) { + for (j in right) { + result[[index]] <- list(i, j) + index <- index + 1 + } + } + result +} + +# Utility function to wrapper above two operations +# Used in Join-family functions +# param (same as genCompactLists): +# tagged_list R list generated via groupByKey with tags(1L, 2L, ...) +# cnull Boolean list where each element determines whether the corresponding list should +# be converted to list(NULL) +joinTaggedList <- function(tagged_list, cnull) { + lists <- genCompactLists(tagged_list, cnull) + mergeCompactLists(lists[[1]], lists[[2]]) +} + +# Utility function to reduce a key-value list with predicate +# Used in *ByKey functions +# param +# pair key-value pair +# keys/vals env of key/value with hashes +# updateOrCreatePred predicate function +# updateFn update or merge function for existing pair, similar with `mergeVal` @combineByKey +# createFn create function for new pair, similar with `createCombiner` @combinebykey +updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) { + # assume hashVal bind to `$hash`, key/val with index 1/2 + hashVal <- pair$hash + key <- pair[[1]] + val <- pair[[2]] + if (updateOrCreatePred(pair)) { + assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals) + } else { + assign(hashVal, do.call(createFn, list(val)), envir = vals) + assign(hashVal, key, envir = keys) + } +} + +# Utility function to convert key&values envs into key-val list +convertEnvsToList <- function(keys, vals) { + lapply(ls(keys), + function(name) { + list(keys[[name]], vals[[name]]) + }) +} + +# Utility function to capture the varargs into environment object +varargsToEnv <- function(...) { + pairs <- as.list(substitute(list(...)))[-1L] + env <- new.env() + for (name in names(pairs)) { + env[[name]] <- pairs[[name]] + } + env +} + +getStorageLevel <- function(newLevel = c("DISK_ONLY", + "DISK_ONLY_2", + "MEMORY_AND_DISK", + "MEMORY_AND_DISK_2", + "MEMORY_AND_DISK_SER", + "MEMORY_AND_DISK_SER_2", + "MEMORY_ONLY", + "MEMORY_ONLY_2", + "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "OFF_HEAP")) { + match.arg(newLevel) + storageLevel <- switch(newLevel, + "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) +} + +# Utility function for functions where an argument needs to be integer but we want to allow +# the user to type (for example) `5` instead of `5L` to avoid a confusing error message. +numToInt <- function(num) { + if (as.integer(num) != num) { + warning(paste("Coercing", as.list(sys.call())[[2]], "to integer.")) + } + as.integer(num) +} + +# create a Seq in JVM +toSeq <- function(...) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) +} + +# create a Seq in JVM from a list +listToSeq <- function(l) { + callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) +} + +# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a +# user defined function (UDF), and to examine variables in the UDF to decide +# if their values should be included in the new function environment. +# param +# node The current AST node in the traversal. +# oldEnv The original function environment. +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { + nodeLen <- length(node) + + if (nodeLen > 1 && typeof(node) == "language") { + # Recursive case: current AST node is an internal node, check for its children. + if (length(node[[1]]) > 1) { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else { # if node[[1]] is length of 1, check for some R special functions. + nodeChar <- as.character(node[[1]]) + if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + for (i in 2:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "<-" || nodeChar == "=" || + nodeChar == "<<-") { # Assignment Ops. + defVar <- node[[2]] + if (length(defVar) == 1 && typeof(defVar) == "symbol") { + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) + } else { + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "function") { # Function definition. + # Add parameter names. + newArgs <- names(node[[2]]) + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) + for (i in 3:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } else if (nodeChar == "$") { # Skip the field. + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) + } else if (nodeChar == "::" || nodeChar == ":::") { + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) + } else { + for (i in 1:nodeLen) { + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) + } + } + } + } else if (nodeLen == 1 && + (typeof(node) == "symbol" || typeof(node) == "language")) { + # Base case: current AST node is a leaf node and a symbol or a function call. + nodeChar <- as.character(node) + if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + func.env <- oldEnv + topEnv <- parent.env(.GlobalEnv) + # Search in function environment, and function's enclosing environments + # up to global environment. There is no need to look into package environments + # above the global or namespace environment that is not SparkR below the global, + # as they are assumed to be loaded on workers. + while (!identical(func.env, topEnv)) { + # Namespaces other than "SparkR" will not be searched. + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in + # attached package environments. + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { + obj <- get(nodeChar, envir = func.env, inherits = FALSE) + if (is.function(obj)) { # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) + } + assign(nodeChar, obj, envir = newEnv) + break + } + } + + # Continue to search in enclosure. + func.env <- parent.env(func.env) + } + } + } +} + +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined +# outside a UDF, and stores them in the function's environment. +# param +# func A function whose closure needs to be captured. +# checkedFunc An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. +# return value +# a new version of func that has an correct environment (closure). +cleanClosure <- function(func, checkedFuncs = new.env()) { + if (is.function(func)) { + newEnv <- new.env(parent = .GlobalEnv) + func.body <- body(func) + oldEnv <- environment(func) + # defVars is an Accumulator of variables names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() + argNames <- names(as.list(args(func))) + for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } + # Recursively examine variables in the function body. + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) + environment(func) <- newEnv + } + func +} diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000..80d796d467 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.onLoad <- function(libname, pkgname) { + sparkR.onLoad(libname, pkgname) +} + diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R new file mode 100644 index 0000000000..8fe711b622 --- /dev/null +++ b/R/pkg/inst/profile/general.R @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) +} diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R new file mode 100644 index 0000000000..7a7f203115 --- /dev/null +++ b/R/pkg/inst/profile/shell.R @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +.First <- function() { + home <- Sys.getenv("SPARK_HOME") + .libPaths(c(file.path(home, "R", "lib"), .libPaths())) + Sys.setenv(NOAWT=1) + + library(utils) + library(SparkR) + sc <- sparkR.init(Sys.getenv("MASTER", unset = "")) + assign("sc", sc, envir=.GlobalEnv) + sqlCtx <- sparkRSQL.init(sc) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + cat("\n Welcome to SparkR!") + cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") +} diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R new file mode 100644 index 0000000000..4bb5f58d83 --- /dev/null +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions on binary files") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("saveAsObjectFile()/objectFile() following textFile() works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + saveAsObjectFile(rdd, fileName2) + rdd <- objectFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + + l <- list(1, 2, 3) + rdd <- parallelize(sc, l) + saveAsObjectFile(rdd, fileName) + rdd <- objectFile(sc, fileName) + expect_equal(collect(rdd), l) + + unlink(fileName, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsObjectFile(counts, fileName2) + counts <- objectFile(sc, fileName2) + + output <- collect(counts) + expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), + list("is", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName1) + unlink(fileName2, recursive = TRUE) +}) + +test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + + rdd1 <- parallelize(sc, "Spark is pretty.") + saveAsObjectFile(rdd1, fileName1) + rdd2 <- parallelize(sc, "Spark is awesome.") + saveAsObjectFile(rdd2, fileName2) + + rdd <- objectFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1, recursive = TRUE) + unlink(fileName2, recursive = TRUE) +}) + diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R new file mode 100644 index 0000000000..c15553ba28 --- /dev/null +++ b/R/pkg/inst/tests/test_binary_function.R @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("binary functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +# File content +mockFile <- c("Spark is pretty.", "Spark is awesome.") + +test_that("union on two RDDs", { + actual <- collect(unionRDD(rdd, rdd)) + expect_equal(actual, as.list(rep(nums, 2))) + + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, c(as.list(nums), mockFile)) + expect_true(getSerializedMode(union.rdd) == "byte") + + rdd<- map(text.rdd, function(x) {x}) + union.rdd <- unionRDD(rdd, text.rdd) + actual <- collect(union.rdd) + expect_equal(actual, as.list(c(mockFile, mockFile))) + expect_true(getSerializedMode(union.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cogroup on two RDDs", { + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + expect_equal(actual, + list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) + + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) + rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + actual <- collect(cogroup.rdd) + + expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R new file mode 100644 index 0000000000..fee91a427d --- /dev/null +++ b/R/pkg/inst/tests/test_broadcast.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("broadcast variables") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rrdd <- parallelize(sc, nums, 2L) + +test_that("using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMatBr <- broadcast(sc, randomMat) + + useBroadcast <- function(x) { + sum(value(randomMatBr) * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) + +test_that("without using broadcast variable", { + randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + + useBroadcast <- function(x) { + sum(randomMat * x) + } + actual <- collect(lapply(rrdd, useBroadcast)) + expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) + expect_equal(actual, expected) +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R new file mode 100644 index 0000000000..e4aab37436 --- /dev/null +++ b/R/pkg/inst/tests/test_context.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R new file mode 100644 index 0000000000..8152b448d0 --- /dev/null +++ b/R/pkg/inst/tests/test_includePackage.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("include R packages") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Partitioned data +nums <- 1:2 +rdd <- parallelize(sc, nums, 2L) + +test_that("include inside function", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + suppressPackageStartupMessages(library(plyr)) + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) + +test_that("use include package", { + # Only run the test if plyr is installed. + if ("plyr" %in% rownames(installed.packages())) { + suppressPackageStartupMessages(library(plyr)) + generateData <- function(x) { + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + result + } + + includePackage(sc, plyr) + data <- lapplyPartition(rdd, generateData) + actual <- collect(data) + } +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R new file mode 100644 index 0000000000..fff028657d --- /dev/null +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("parallelize() and collect()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) +strPairs <- list(list(strList, strList), list(strList, strList)) + +# JavaSparkContext handle +jsc <- sparkR.init() + +# Tests + +test_that("parallelize() on simple vectors and lists returns an RDD", { + numVectorRDD <- parallelize(jsc, numVector, 1) + numVectorRDD2 <- parallelize(jsc, numVector, 10) + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + + rdds <- c(numVectorRDD, + numVectorRDD2, + numListRDD, + numListRDD2, + strVectorRDD, + strVectorRDD2, + strListRDD, + strListRDD2) + + for (rdd in rdds) { + expect_true(inherits(rdd, "RDD")) + expect_true(.hasSlot(rdd, "jrdd") + && inherits(rdd@jrdd, "jobj") + && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) + } +}) + +test_that("collect(), following a parallelize(), gives back the original collections", { + numVectorRDD <- parallelize(jsc, numVector, 10) + expect_equal(collect(numVectorRDD), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(collect(numListRDD), as.list(numList)) + expect_equal(collect(numListRDD2), as.list(numList)) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(collect(strVectorRDD), as.list(strVector)) + expect_equal(collect(strVectorRDD2), as.list(strVector)) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(collect(strListRDD), as.list(strList)) + expect_equal(collect(strListRDD2), as.list(strList)) +}) + +test_that("regression: collect() following a parallelize() does not drop elements", { + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 + collLen <- 10 + numPart <- 6 + expected <- runif(collLen) + actual <- collect(parallelize(jsc, expected, numPart)) + expect_equal(actual, as.list(expected)) +}) + +test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + # use the pairwise logical to indicate pairwise data + numPairsRDDD1 <- parallelize(jsc, numPairs, 1) + numPairsRDDD2 <- parallelize(jsc, numPairs, 2) + numPairsRDDD3 <- parallelize(jsc, numPairs, 3) + expect_equal(collect(numPairsRDDD1), numPairs) + expect_equal(collect(numPairsRDDD2), numPairs) + expect_equal(collect(numPairsRDDD3), numPairs) + # can also leave out the parameter name, if the params are supplied in order + strPairsRDDD1 <- parallelize(jsc, strPairs, 1) + strPairsRDDD2 <- parallelize(jsc, strPairs, 2) + expect_equal(collect(strPairsRDDD1), strPairs) + expect_equal(collect(strPairsRDDD2), strPairs) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R new file mode 100644 index 0000000000..f75e0817b9 --- /dev/null +++ b/R/pkg/inst/tests/test_rdd.R @@ -0,0 +1,644 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic RDD functions") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +nums <- 1:10 +rdd <- parallelize(sc, nums, 2L) + +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +test_that("get number of partitions in RDD", { + expect_equal(numPartitions(rdd), 2) + expect_equal(numPartitions(intRdd), 2) +}) + +test_that("first on RDD", { + expect_true(first(rdd) == 1) + newrdd <- lapply(rdd, function(x) x + 1) + expect_true(first(newrdd) == 2) +}) + +test_that("count and length on RDD", { + expect_equal(count(rdd), 10) + expect_equal(length(rdd), 10) +}) + +test_that("count by values and keys", { + mods <- lapply(rdd, function(x) { x %% 3 }) + actual <- countByValue(mods) + expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + actual <- countByKey(intRdd) + expected <- list(list(2L, 2L), list(1L, 2L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("lapply on RDD", { + multiples <- lapply(rdd, function(x) { 2 * x }) + actual <- collect(multiples) + expect_equal(actual, as.list(nums * 2)) +}) + +test_that("lapplyPartition on RDD", { + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("mapPartitions on RDD", { + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) + actual <- collect(sums) + expect_equal(actual, list(15, 40)) +}) + +test_that("flatMap() on RDDs", { + flat <- flatMap(intRdd, function(x) { list(x, x) }) + actual <- collect(flat) + expect_equal(actual, rep(intPairs, each=2)) +}) + +test_that("filterRDD on RDD", { + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list(2, 4, 6, 8, 10)) + + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) + actual <- collect(filtered.rdd) + expect_equal(actual, list(list(1L, -1))) + + # Filter out all elements. + filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) + actual <- collect(filtered.rdd) + expect_equal(actual, list()) +}) + +test_that("lookup on RDD", { + vals <- lookup(intRdd, 1L) + expect_equal(vals, list(-1, 200)) + + vals <- lookup(intRdd, 3L) + expect_equal(vals, list()) +}) + +test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + rdd2 <- rdd + for (i in 1:12) + rdd2 <- lapplyPartitionsWithIndex( + rdd2, function(split, part) { + part <- as.list(unlist(part) * split + i) + }) + rdd2 <- lapply(rdd2, function(x) x + x) + actual <- collect(rdd2) + expected <- list(24, 24, 24, 24, 24, + 168, 170, 172, 174, 176) + expect_equal(actual, expected) +}) + +test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + # RDD + rdd2 <- rdd + # PipelinedRDD + rdd2 <- lapplyPartitionsWithIndex( + rdd2, + function(split, part) { + part <- as.list(unlist(part) * split) + }) + + cache(rdd2) + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + persist(rdd2, "MEMORY_AND_DISK") + expect_true(rdd2@env$isCached) + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + + unpersist(rdd2) + expect_false(rdd2@env$isCached) + + setCheckpointDir(sc, "checkpoints") + checkpoint(rdd2) + expect_true(rdd2@env$isCheckpointed) + + rdd2 <- lapply(rdd2, function(x) x) + expect_false(rdd2@env$isCached) + expect_false(rdd2@env$isCheckpointed) + + # make sure the data is collectable + collect(rdd2) + + unlink("checkpoints") +}) + +test_that("reduce on RDD", { + sum <- reduce(rdd, "+") + expect_equal(sum, 55) + + # Also test with an inline function + sumInline <- reduce(rdd, function(x, y) { x + y }) + expect_equal(sumInline, 55) +}) + +test_that("lapply with dependency", { + fa <- 5 + multiples <- lapply(rdd, function(x) { fa * x }) + actual <- collect(multiples) + + expect_equal(actual, as.list(nums * 5)) +}) + +test_that("lapplyPartitionsWithIndex on RDDs", { + func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) } + actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + expect_equal(actual, list(list(0, 15), list(1, 40))) + + pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) + partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } + mkTup <- function(splitIndex, part) { list(splitIndex, part) } + actual <- collect(lapplyPartitionsWithIndex( + partitionBy(pairsRDD, 2L, partitionByParity), + mkTup), + FALSE) + expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), + list(1, list(list(4, 8))))) +}) + +test_that("sampleRDD() on RDDs", { + expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) +}) + +test_that("takeSample() on RDDs", { + # ported from RDDSuite.scala, modified seeds + data <- parallelize(sc, 1:100, 2L) + for (seed in 4:5) { + s <- takeSample(data, FALSE, 20L, seed) + expect_equal(length(s), 20L) + expect_equal(length(unique(s)), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, FALSE, 200L, seed) + expect_equal(length(s), 100L) + expect_equal(length(unique(s)), 100L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 20L, seed) + expect_equal(length(s), 20L) + for (elem in s) { + expect_true(elem >= 1 && elem <= 100) + } + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 100L, seed) + expect_equal(length(s), 100L) + # Chance of getting all distinct elements is astronomically low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } + for (seed in 4:5) { + s <- takeSample(data, TRUE, 200L, seed) + expect_equal(length(s), 200L) + # Chance of getting all distinct elements is still quite low, so test we + # got < 100 + expect_true(length(unique(s)) < 100L) + } +}) + +test_that("mapValues() on pairwise RDDs", { + multiples <- mapValues(intRdd, function(x) { x * 2 }) + actual <- collect(multiples) + expected <- lapply(intPairs, function(x) { + list(x[[1]], x[[2]] * 2) + }) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("flatMapValues() on pairwise RDDs", { + l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + actual <- collect(flatMapValues(l, function(x) { x })) + expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + + # Generate x to x+1 for every value + actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + expect_equal(actual, + list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), + list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) +}) + +test_that("reduceByKeyLocally() on PairwiseRDDs", { + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, 6), list(1.1, 3)))) + + pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3), + list("bb", 5)), 4L) + actual <- reduceByKeyLocally(pairs, "+") + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5)))) +}) + +test_that("distinct() on RDDs", { + nums.rep2 <- rep(1:10, 2) + rdd.rep2 <- parallelize(sc, nums.rep2, 2L) + uniques <- distinct(rdd.rep2) + actual <- sort(unlist(collect(uniques))) + expect_equal(actual, nums) +}) + +test_that("maximum() on RDDs", { + max <- maximum(rdd) + expect_equal(max, 10) +}) + +test_that("minimum() on RDDs", { + min <- minimum(rdd) + expect_equal(min, 1) +}) + +test_that("sumRDD() on RDDs", { + sum <- sumRDD(rdd) + expect_equal(sum, 55) +}) + +test_that("keyBy on RDDs", { + func <- function(x) { x*x } + keys <- keyBy(rdd, func) + actual <- collect(keys) + expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) +}) + +test_that("repartition/coalesce on RDDs", { + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements + + # repartition + r1 <- repartition(rdd, 2) + expect_equal(numPartitions(r1), 2L) + count <- length(collectPartition(r1, 0L)) + expect_true(count >= 8 && count <= 12) + + r2 <- repartition(rdd, 6) + expect_equal(numPartitions(r2), 6L) + count <- length(collectPartition(r2, 0L)) + expect_true(count >=0 && count <= 4) + + # coalesce + r3 <- coalesce(rdd, 1) + expect_equal(numPartitions(r3), 1L) + count <- length(collectPartition(r3, 0L)) + expect_equal(count, 20) +}) + +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) + actual <- collect(sortedRdd2) + expect_equal(actual, as.list(nums)) +}) + +test_that("takeOrdered() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l)))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- takeOrdered(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l)))[1:3]) +}) + +test_that("top() on RDDs", { + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) + rdd <- parallelize(sc, l) + actual <- top(rdd, 6L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) + + l <- list("e", "d", "c", "d", "a") + rdd <- parallelize(sc, l) + actual <- top(rdd, 3L) + expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3]) +}) + +test_that("fold() on RDDs", { + actual <- fold(rdd, 0, "+") + expect_equal(actual, Reduce("+", nums, 0)) + + rdd <- parallelize(sc, list()) + actual <- fold(rdd, 0, "+") + expect_equal(actual, 0) +}) + +test_that("aggregateRDD() on RDDs", { + rdd <- parallelize(sc, list(1, 2, 3, 4)) + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(10, 4)) + + rdd <- parallelize(sc, list()) + actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) + expect_equal(actual, list(0, 0)) +}) + +test_that("zipWithUniqueId() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 3), list("c", 1), + list("d", 4), list("e", 2)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithUniqueId(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("zipWithIndex() on RDDs", { + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) + + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) + actual <- collect(zipWithIndex(rdd)) + expected <- list(list("a", 0), list("b", 1), list("c", 2), + list("d", 3), list("e", 4)) + expect_equal(actual, expected) +}) + +test_that("glom() on RDD", { + rdd <- parallelize(sc, as.list(1:4), 2L) + actual <- collect(glom(rdd)) + expect_equal(actual, list(list(1, 2), list(3, 4))) +}) + +test_that("keys() on RDDs", { + keys <- keys(intRdd) + actual <- collect(keys) + expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) +}) + +test_that("values() on RDDs", { + values <- values(intRdd) + actual <- collect(values) + expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) +}) + +test_that("pipeRDD() on RDDs", { + actual <- collect(pipeRDD(rdd, "more")) + expected <- as.list(as.character(1:10)) + expect_equal(actual, expected) + + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) + actual <- collect(pipeRDD(trailed.rdd, "sort")) + expected <- list("", "1", "2", "3") + expect_equal(actual, expected) + + rev.nums <- 9:0 + rev.rdd <- parallelize(sc, rev.nums, 2L) + actual <- collect(pipeRDD(rev.rdd, "sort")) + expected <- as.list(as.character(c(5:9, 0:4))) + expect_equal(actual, expected) +}) + +test_that("zipRDD() on RDDs", { + rdd1 <- parallelize(sc, 0:4, 2) + rdd2 <- parallelize(sc, 1000:1004, 2) + actual <- collect(zipRDD(rdd1, rdd2)) + expect_equal(actual, + list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipRDD(rdd, rdd)) + expected <- lapply(mockFile, function(x) { list(x ,x) }) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipRDD(rdd1, rdd)) + expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) + expect_equal(actual, expected) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(zipRDD(rdd, rdd1)) + expected <- lapply(mockFile, function(x) { list(x, x) }) + expect_equal(actual, expected) + + unlink(fileName) +}) + +test_that("join() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(join(rdd1, rdd2, 2L)) + expect_equal(actual, list()) +}) + +test_that("leftOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) +}) + +test_that("rightOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("fullOuterJoin() on pairwise RDDs", { + rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) + rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(expected)) + + rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) + rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + + rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) + rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) +}) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) + + rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) + numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + sortedRdd2 <- sortByKey(numPairsRdd2) + actual <- collect(sortedRdd2) + expect_equal(actual, numPairs) + + # sort by string keys + l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) + rdd3 <- parallelize(sc, l, 2L) + sortedRdd3 <- sortByKey(rdd3) + actual <- collect(sortedRdd3) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # test on the boundary cases + + # boundary case 1: the RDD to be sorted has only 1 partition + rdd4 <- parallelize(sc, l, 1L) + sortedRdd4 <- sortByKey(rdd4) + actual <- collect(sortedRdd4) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 2: the sorted RDD has only 1 partition + rdd5 <- parallelize(sc, l, 2L) + sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) + actual <- collect(sortedRdd5) + expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) + + # boundary case 3: the RDD to be sorted has only 1 element + l2 <- list(list("a", 1)) + rdd6 <- parallelize(sc, l2, 2L) + sortedRdd6 <- sortByKey(rdd6) + actual <- collect(sortedRdd6) + expect_equal(actual, l2) + + # boundary case 4: the RDD to be sorted has 0 element + l3 <- list() + rdd7 <- parallelize(sc, l3, 2L) + sortedRdd7 <- sortByKey(rdd7) + actual <- collect(sortedRdd7) + expect_equal(actual, l3) +}) + +test_that("collectAsMap() on a pairwise RDD", { + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = 2, `3` = 4)) + + rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(a = 1, b = 2)) + + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) + + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) + vals <- collectAsMap(rdd) + expect_equal(vals, list(`1` = "a", `2` = "b")) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R new file mode 100644 index 0000000000..d1da8232ae --- /dev/null +++ b/R/pkg/inst/tests/test_shuffle.R @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("partitionBy, groupByKey, reduceByKey etc.") + +# JavaSparkContext handle +sc <- sparkR.init() + +# Data +intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) +intRdd <- parallelize(sc, intPairs, 2L) + +doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200)) +doubleRdd <- parallelize(sc, doublePairs, 2L) + +numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1), + list(3L, 0)) +numPairsRdd <- parallelize(sc, numPairs, length(numPairs)) + +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ") +strListRDD <- parallelize(sc, strList, 4) + +test_that("groupByKey for integers", { + grouped <- groupByKey(intRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("groupByKey for doubles", { + grouped <- groupByKey(doubleRdd, 2L) + + actual <- collect(grouped) + + expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for ints", { + reduced <- reduceByKey(intRdd, "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("reduceByKey for doubles", { + reduced <- reduceByKey(doubleRdd, "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for ints", { + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) + + actual <- collect(reduced) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("combineByKey for doubles", { + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("aggregateByKey", { + # test aggregateByKey for int keys + rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test aggregateByKey for string keys + rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) + + zeroValue <- list(0, 0) + seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } + combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + + actual <- collect(aggregatedRDD) + + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + +test_that("foldByKey", { + # test foldByKey for int keys + folded <- foldByKey(intRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(2L, 101), list(1L, 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for double keys + folded <- foldByKey(doubleRdd, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list(1.5, 199), list(2.5, 101)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for string keys + stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) + + stringKeyRDD <- parallelize(sc, stringKeyPairs) + folded <- foldByKey(stringKeyRDD, 0, "+", 2L) + + actual <- collect(folded) + + expected <- list(list("b", 101), list("a", 199)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) + + # test foldByKey for empty pair RDD + rdd <- parallelize(sc, list()) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list() + expect_equal(actual, expected) + + # test foldByKey for RDD with only 1 pair + rdd <- parallelize(sc, list(list(1, 1))) + folded <- foldByKey(rdd, 0, "+", 2L) + actual <- collect(folded) + expected <- list(list(1, 1)) + expect_equal(actual, expected) +}) + +test_that("partitionBy() partitions data correctly", { + # Partition by magnitude + partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } + + resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + + expected_first <- list(list(1, 100), list(2, 200)) # key < 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("partitionBy works with dependencies", { + kOne <- 1 + partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } + + # Partition by parity + resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + + # keys even; 100 %% 2 == 0 + expected_first <- list(list(2, 200), list(4, -1)) + # keys odd; 3 %% 2 == 1 + expected_second <- list(list(1, 100), list(3, 1), list(3, 0)) + actual_first <- collectPartition(resultRDD, 0L) + actual_second <- collectPartition(resultRDD, 1L) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) + +test_that("test partitionBy with string keys", { + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + resultRDD <- partitionBy(wordCount, 2L) + expected_first <- list(list("Dexter", 1), list("Dexter", 1)) + expected_second <- list(list("and", 1), list("and", 1)) + + actual_first <- Filter(function(item) { item[[1]] == "Dexter" }, + collectPartition(resultRDD, 0L)) + actual_second <- Filter(function(item) { item[[1]] == "and" }, + collectPartition(resultRDD, 1L)) + + expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first)) + expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second)) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R new file mode 100644 index 0000000000..cf5cf6d169 --- /dev/null +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -0,0 +1,695 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlCtx <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") +writeLines(mockLines, jsonPath) + +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(a = 1L, b = "2")), + list(type = "struct", + fields = list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)))) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- createDataFrame(sqlCtx, rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- toDF(rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlCtx, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlCtx, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlCtx, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlCtx, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + +test_that("jsonFile() on a local file returns a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_true(count(rdd) == 3) + df <- jsonRDD(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- jsonRDD(sqlCtx, rdd2) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlCtx, "table1") + uncacheTable(sqlCtx, "table1") + clearCache(sqlCtx) + dropTempTable(sqlCtx, "table1") +}) + +test_that("test tableNames and tables", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + expect_true(length(tableNames(sqlCtx)) == 1) + df <- tables(sqlCtx) + expect_true(count(df) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + expect_true(inherits(newdf, "DataFrame")) + expect_true(count(newdf) == 1) + dropTempTable(sqlCtx, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", "overwrite") + dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + saveDF(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_true(count(sql(sqlCtx, "select * from table1")) == 5) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlCtx, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_true(count(sql(sqlCtx, "select * from table1")) == 2) + expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlCtx, "table1") +}) + +test_that("table() returns a new DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + registerTempTable(df, "table1") + tabledf <- table(sqlCtx, "table1") + expect_true(inherits(tabledf, "DataFrame")) + expect_true(count(tabledf) == 3) + dropTempTable(sqlCtx, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toRDD(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(count(testRDD) == 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- jsonFile(sqlCtx, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_true(inherits(unioned, "RDD")) + expect_true(SparkR:::getSerializedMode(unioned) == "byte") + expect_true(collect(unioned)[[2]]$name == "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_true(inherits(unionByte, "RDD")) + expect_true(SparkR:::getSerializedMode(unionByte) == "byte") + expect_true(collect(unionByte)[[1]] == 1) + expect_true(collect(unionByte)[[12]]$name == "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_true(inherits(unionString, "RDD")) + expect_true(SparkR:::getSerializedMode(unionString) == "byte") + expect_true(collect(unionString)[[1]] == "Michael") + expect_true(collect(unionString)[[5]]$name == "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern="spark-test", fileext=".tmp") + df <- jsonFile(sqlCtx, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_true(inherits(objectIn, "RDD")) + expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_true(inherits(testRDD, "RDD")) + collected <- collect(testRDD) + expect_true(collected[[1]]$name == "Michael") + expect_true(collected[[2]]$newCol == "35") +}) + +test_that("collect() returns a data.frame", { + df <- jsonFile(sqlCtx, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_true(names(rdf)[1] == "age") + expect_true(nrow(rdf) == 3) + expect_true(ncol(rdf) == 2) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- jsonFile(sqlCtx, jsonPath) + dfLimited <- limit(df, 2) + expect_true(inherits(dfLimited, "DataFrame")) + expect_true(count(dfLimited) == 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- jsonFile(sqlCtx, jsonPath) + expect_true(nrow(collect(df)) == nrow(take(df, 10))) + expect_true(ncol(collect(df)) == ncol(take(df, 10))) +}) + +test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_true(inherits(second, "RDD")) + expect_true(count(second) == 3) + expect_true(collect(second)[[2]]$age == 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- jsonFile(sqlCtx, jsonPath) + testSchema <- schema(df) + expect_true(length(testSchema$fields()) == 2) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") + expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") + expect_true(testSchema$fields()[[1]]$name() == "age") + + testTypes <- dtypes(df) + expect_true(length(testTypes[[1]]) == 2) + expect_true(testTypes[[1]][1] == "age") + + testCols <- columns(df) + expect_true(length(testCols) == 2) + expect_true(testCols[2] == "name") + + testNames <- names(df) + expect_true(length(testNames) == 2) + expect_true(testNames[2] == "name") +}) + +test_that("head() and first() return the correct data", { + df <- jsonFile(sqlCtx, jsonPath) + testHead <- head(df) + expect_true(nrow(testHead) == 3) + expect_true(ncol(testHead) == 2) + + testHead2 <- head(df, 2) + expect_true(nrow(testHead2) == 2) + expect_true(ncol(testHead2) == 2) + + testFirst <- first(df) + expect_true(nrow(testFirst) == 1) +}) + +test_that("distinct() on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + + df <- jsonFile(sqlCtx, jsonPathWithDup) + uniques <- distinct(df) + expect_true(inherits(uniques, "DataFrame")) + expect_true(count(uniques) == 3) +}) + +test_that("sampleDF on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sampled <- sampleDF(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_true(inherits(sampled, "DataFrame")) + sampled2 <- sampleDF(df, FALSE, 0.1) + expect_true(count(sampled2) < 3) +}) + +test_that("select operators", { + df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + expect_true(inherits(df$name, "Column")) + expect_true(inherits(df[[2]], "Column")) + expect_true(inherits(df[["age"]], "Column")) + + expect_true(inherits(df[,1], "DataFrame")) + expect_equal(columns(df[,1]), c("name")) + expect_equal(columns(df[,"age"]), c("age")) + df2 <- df[,c("age", "name")] + expect_true(inherits(df2, "DataFrame")) + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) +}) + +test_that("select with column", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- select(df, "name") + expect_true(columns(df1) == c("name")) + expect_true(count(df1) == 3) + + df2 <- select(df, df$age) + expect_true(columns(df2) == c("age")) + expect_true(count(df2) == 3) +}) + +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_true(names(selected) == "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_true(count(selected2) == 3) +}) + +test_that("column calculation", { + df <- jsonFile(sqlCtx, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_true(names(d) == c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("load() from json file", { + df <- loadDF(sqlCtx, jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) +}) + +test_that("save() as parquet file", { + df <- loadDF(sqlCtx, jsonPath, "json") + saveDF(df, parquetPath, "parquet", mode="overwrite") + df2 <- loadDF(sqlCtx, parquetPath, "parquet") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) +}) + +test_that("test HiveContext", { + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 3) + df2 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df2, "DataFrame")) + expect_true(count(df2) == 3) + + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + saveAsTable(df, "json", "json", "append", path = jsonPath2) + df3 <- sql(hiveCtx, "select * from json") + expect_true(inherits(df3, "DataFrame")) + expect_true(count(df3) == 6) +}) + +test_that("column operators", { + c <- SparkR:::col("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) +}) + +test_that("column functions", { + c <- SparkR:::col("a") + c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) + c3 <- lower(c) + upper(c) + first(c) + last(c) + c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") +}) + +test_that("string operators", { + df <- jsonFile(sqlCtx, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") +}) + +test_that("group by", { + df <- jsonFile(sqlCtx, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_true(1 == count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_true(1 == count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_true(inherits(gd, "GroupedData")) + df2 <- count(gd) + expect_true(inherits(df2, "DataFrame")) + expect_true(3 == count(df2)) + + df3 <- agg(gd, age = "sum") + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + + df3 <- agg(gd, age = sum(df$age)) + expect_true(inherits(df3, "DataFrame")) + expect_true(3 == count(df3)) + expect_equal(columns(df3), c("name", "age")) + + df4 <- sum(gd, "age") + expect_true(inherits(df4, "DataFrame")) + expect_true(3 == count(df4)) + expect_true(3 == count(mean(gd, "age"))) + expect_true(3 == count(max(gd, "age"))) +}) + +test_that("sortDF() and orderBy() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + sorted <- sortDF(df, df$age) + expect_true(collect(sorted)[1,2] == "Michael") + + sorted2 <- sortDF(df, "name") + expect_true(collect(sorted2)[2,"age"] == 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_true(collect(sorted3)[2, "age"] == 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_true(first(sorted4)$name == "Michael") + expect_true(collect(sorted4)[3,"name"] == "Andy") +}) + +test_that("filter() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + filtered <- filter(df, "age > 20") + expect_true(count(filtered) == 1) + expect_true(collect(filtered)$name == "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_true(count(filtered2) == 2) + expect_true(collect(filtered2)$age[2] == 19) +}) + +test_that("join() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- jsonFile(sqlCtx, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_true(count(joined) == 12) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_true(count(joined2) == 3) + + joined3 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_true(count(joined3) == 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_true(count(joined4) == 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- jsonFile(sqlCtx, jsonPath) + testRDD <- toJSON(df) + expect_true(inherits(testRDD, "RDD")) + expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ") +}) + +test_that("isLocal()", { + df <- jsonFile(sqlCtx, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), subtract(), and intersect() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath2) + df2 <- loadDF(sqlCtx, jsonPath2, "json") + + unioned <- sortDF(unionAll(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(unioned) == 6) + expect_true(first(unioned)$name == "Michael") + + subtracted <- sortDF(subtract(df, df2), desc(df$age)) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(subtracted) == 2) + expect_true(first(subtracted)$name == "Justin") + + intersected <- sortDF(intersect(df, df2), df$age) + expect_true(inherits(unioned, "DataFrame")) + expect_true(count(intersected) == 1) + expect_true(first(intersected)$name == "Andy") +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- jsonFile(sqlCtx, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_true(length(columns(newDF)) == 3) + expect_true(columns(newDF)[3] == "newAge") + expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_true(length(columns(newDF2)) == 2) + expect_true(columns(newDF2)[1] == "newerAge") +}) + +test_that("saveDF() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath) + expect_true(inherits(parquetDF, "DataFrame")) + expect_equal(count(df), count(parquetDF)) +}) + +test_that("parquetFile works with multiple input paths", { + df <- jsonFile(sqlCtx, jsonPath) + saveDF(df, parquetPath, "parquet", mode="overwrite") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + saveDF(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + expect_true(inherits(parquetDF, "DataFrame")) + expect_true(count(parquetDF) == count(df)*2) +}) + +unlink(parquetPath) +unlink(jsonPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R new file mode 100644 index 0000000000..7f4c7c315d --- /dev/null +++ b/R/pkg/inst/tests/test_take.R @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("tests RDD function take()") + +# Mock data +numVector <- c(-10:97) +numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10) +strVector <- c("Dexter Morgan: I suppose I should be upset, even feel", + "violated, but I'm not. No, in fact, I think this is a friendly", + "message, like \"Hey, wanna play?\" and yes, I want to play. ", + "I really, really do.") +strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", + "other times it helps me control the chaos.", + "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ", + "raising me. But they're both dead now. I didn't kill them. Honest.") + +# JavaSparkContext handle +jsc <- sparkR.init() + +test_that("take() gives back the original elements in correct count and order", { + numVectorRDD <- parallelize(jsc, numVector, 10) + # case: number of elements to take is less than the size of the first partition + expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + # case: number of elements to take is the same as the size of the first partition + expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + # case: number of elements to take is greater than all elements + expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + + numListRDD <- parallelize(jsc, numList, 1) + numListRDD2 <- parallelize(jsc, numList, 4) + expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) + expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) + expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(take(numListRDD2, 999), numList) + + strVectorRDD <- parallelize(jsc, strVector, 2) + strVectorRDD2 <- parallelize(jsc, strVector, 3) + expect_equal(take(strVectorRDD, 4), as.list(strVector)) + expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + + strListRDD <- parallelize(jsc, strList, 4) + strListRDD2 <- parallelize(jsc, strList, 1) + expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + + expect_true(length(take(strListRDD, 0)) == 0) + expect_true(length(take(strVectorRDD, 0)) == 0) + expect_true(length(take(numListRDD, 0)) == 0) + expect_true(length(take(numVectorRDD, 0)) == 0) +}) + diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R new file mode 100644 index 0000000000..7bb3e80031 --- /dev/null +++ b/R/pkg/inst/tests/test_textFile.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("the textFile() function") + +# JavaSparkContext handle +sc <- sparkR.init() + +mockFile = c("Spark is pretty.", "Spark is awesome.") + +test_that("textFile() on a local file returns an RDD", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_true(inherits(rdd, "RDD")) + expect_true(count(rdd) > 0) + expect_true(count(rdd) == 2) + + unlink(fileName) +}) + +test_that("textFile() followed by a collect() returns the same content", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName) +}) + +test_that("textFile() word count works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + output <- collect(counts) + expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), + list("Spark", 2)) + expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) + + unlink(fileName) +}) + +test_that("several transformations on RDD created by textFile()", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) # RDD + for (i in 1:10) { + # PipelinedRDD initially created from RDD + rdd <- lapply(rdd, function(x) paste(x, x)) + } + collect(rdd) + + unlink(fileName) +}) + +test_that("textFile() followed by a saveAsTextFile() returns the same content", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + saveAsTextFile(rdd, fileName2) + rdd <- textFile(sc, fileName2) + expect_equal(collect(rdd), as.list(mockFile)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("saveAsTextFile() on a parallelized list works as expected", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + l <- list(1, 2, 3) + rdd <- parallelize(sc, l) + saveAsTextFile(rdd, fileName) + rdd <- textFile(sc, fileName) + expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + + unlink(fileName) +}) + +test_that("textFile() and saveAsTextFile() word count works as expected", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName1) + + rdd <- textFile(sc, fileName1) + + words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] }) + wordCount <- lapply(words, function(word) { list(word, 1L) }) + + counts <- reduceByKey(wordCount, "+", 2L) + + saveAsTextFile(counts, fileName2) + rdd <- textFile(sc, fileName2) + + output <- collect(rdd) + expected <- list(list("awesome.", 1), list("Spark", 2), + list("pretty.", 1), list("is", 2)) + expectedStr <- lapply(expected, function(x) { toString(x) }) + expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("textFile() on multiple paths", { + fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines("Spark is pretty.", fileName1) + writeLines("Spark is awesome.", fileName2) + + rdd <- textFile(sc, c(fileName1, fileName2)) + expect_true(count(rdd) == 2) + + unlink(fileName1) + unlink(fileName2) +}) + +test_that("Pipelined operations on RDDs created using textFile", { + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + + lengths <- lapply(rdd, function(x) { length(x) }) + expect_equal(collect(lengths), list(1, 1)) + + lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) + expect_equal(collect(lengthsPipelined), list(11, 11)) + + lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) + expect_equal(collect(lengths30), list(31, 31)) + + lengths20 <- lapply(lengths, function(x) { x + 20 }) + expect_equal(collect(lengths20), list(21, 21)) + + unlink(fileName) +}) + diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R new file mode 100644 index 0000000000..9c5bb42793 --- /dev/null +++ b/R/pkg/inst/tests/test_utils.R @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in utils.R") + +# JavaSparkContext handle +sc <- sparkR.init() + +test_that("convertJListToRList() gives back (deserializes) the original JLists + of strings and integers", { + # It's hard to manually create a Java List using rJava, since it does not + # support generics well. Instead, we rely on collect() returning a + # JList. + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 1L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, nums) + + strs <- as.list("hello", "spark") + rdd <- parallelize(sc, strs, 2L) + jList <- callJMethod(rdd@jrdd, "collect") + rList <- convertJListToRList(jList, flatten = TRUE) + expect_equal(rList, strs) +}) + +test_that("serializeToBytes on RDD", { + # File content + mockFile <- c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + text.rdd <- textFile(sc, fileName) + expect_true(getSerializedMode(text.rdd) == "string") + ser.rdd <- serializeToBytes(text.rdd) + expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_true(getSerializedMode(ser.rdd) == "byte") + + unlink(fileName) +}) + +test_that("cleanClosure on R functions", { + y <- c(1, 2, 3) + g <- function(x) { x + 1 } + f <- function(x) { g(x) + y } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # y, g + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + # Test for nested enclosures and package variables. + env2 <- new.env() + funcEnv <- new.env(parent = env2) + f <- function(x) { log(g(x) + y) } + environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 2) # "min" should not be included + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + actual <- get("g", envir = env, inherits = FALSE) + expect_equal(actual, g) + + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 + g <- function(x) { x + y } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environemnt of g. + newG <- get("g", envir = env, inherits = FALSE) + env <- environment(newG) + expect_equal(length(ls(env)), 1) + actual <- get("y", envir = env, inherits = FALSE) + expect_equal(actual, y) + + # Test for function (and variable) definitions. + f <- function(x) { + g <- function(y) { y * 2 } + g(x) + } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. + + # Test for overriding variables in base namespace (Issue: SparkR-196). + nums <- as.list(1:10) + rdd <- parallelize(sc, nums, 2L) + t = 4 # Override base::t in .GlobalEnv. + f <- function(x) { x > t } + newF <- cleanClosure(f) + env <- environment(newF) + expect_equal(ls(env), "t") + expect_equal(get("t", envir = env, inherits = FALSE), t) + actual <- collect(lapply(rdd, f)) + expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) + expect_equal(actual, expected) + + # Test for broadcast variables. + a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + aBroadcast <- broadcast(sc, a) + normMultiply <- function(x) { norm(aBroadcast$value) * x } + newnormMultiply <- SparkR:::cleanClosure(normMultiply) + env <- environment(newnormMultiply) + expect_equal(ls(env), "aBroadcast") + expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R new file mode 100644 index 0000000000..3584b418a7 --- /dev/null +++ b/R/pkg/inst/worker/daemon.R @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker daemon + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") + +# preload SparkR package, speedup worker +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) + +while (TRUE) { + ready <- socketSelect(list(inputCon)) + if (ready) { + port <- SparkR:::readInt(inputCon) + # There is a small chance that it could be interrupted by signal, retry one time + if (length(port) == 0) { + port <- SparkR:::readInt(inputCon) + if (length(port) == 0) { + cat("quitting daemon\n") + quit(save = "no") + } + } + p <- parallel:::mcfork() + if (inherits(p, "masterProcess")) { + close(inputCon) + Sys.setenv(SPARKR_WORKER_PORT = port) + source(script) + # Set SIGUSR1 so that child can exit + tools::pskill(Sys.getpid(), tools::SIGUSR1) + parallel:::mcexit(0L) + } + } +} diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R new file mode 100644 index 0000000000..c6542928e8 --- /dev/null +++ b/R/pkg/inst/worker/worker.R @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Worker class + +rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +# Set libPaths to include SparkR package as loadNamespace needs this +# TODO: Figure out if we can avoid this by not loading any objects that require +# SparkR namespace +.libPaths(c(rLibDir, .libPaths())) +suppressPackageStartupMessages(library(SparkR)) + +port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) +inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") +outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") + +# read the index of the current partition inside the RDD +partition <- SparkR:::readInt(inputCon) + +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) + +# Include packages as required +packageNames <- unserialize(SparkR:::readRaw(inputCon)) +for (pkg in packageNames) { + suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) +} + +# read function dependencies +funcLen <- SparkR:::readInt(inputCon) +computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) +env <- environment(computeFunc) +parent.env(env) <- .GlobalEnv # Attach under global environment. + +# Read and set broadcast variables +numBroadcastVars <- SparkR:::readInt(inputCon) +if (numBroadcastVars > 0) { + for (bcast in seq(1:numBroadcastVars)) { + bcastId <- SparkR:::readInt(inputCon) + value <- unserialize(SparkR:::readRaw(inputCon)) + setBroadcastValue(bcastId, value) + } +} + +# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int +# as number of partitions to create. +numPartitions <- SparkR:::readInt(inputCon) + +isEmpty <- SparkR:::readInt(inputCon) + +if (isEmpty != 0) { + + if (numPartitions == -1) { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- as.list(readLines(inputCon)) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + output <- computeFunc(partition, data) + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + SparkR:::writeStrings(outputCon, output) + } + } else { + if (deserializer == "byte") { + # Now read as many characters as described in funcLen + data <- SparkR:::readDeserialize(inputCon) + } else if (deserializer == "string") { + data <- readLines(inputCon) + } else if (deserializer == "row") { + data <- SparkR:::readDeserializeRows(inputCon) + } + + res <- new.env() + + # Step 1: hash the data to an environment + hashTupleToEnvir <- function(tuple) { + # NOTE: execFunction is the hash function here + hashVal <- computeFunc(tuple[[1]]) + bucket <- as.character(hashVal %% numPartitions) + acc <- res[[bucket]] + # Create a new accumulator + if (is.null(acc)) { + acc <- SparkR:::initAccumulator() + } + SparkR:::addItemToAccumulator(acc, tuple) + res[[bucket]] <- acc + } + invisible(lapply(data, hashTupleToEnvir)) + + # Step 2: write out all of the environment as key-value pairs. + for (name in ls(res)) { + SparkR:::writeInt(outputCon, 2L) + SparkR:::writeInt(outputCon, as.integer(name)) + # Truncate the accumulator list to the number of elements we have + length(res[[name]]$data) <- res[[name]]$counter + SparkR:::writeRawSerialize(outputCon, res[[name]]$data) + } + } +} + +# End of output +if (serializer %in% c("byte", "row")) { + SparkR:::writeInt(outputCon, 0L) +} + +close(outputCon) +close(inputCon) diff --git a/R/pkg/src/Makefile b/R/pkg/src/Makefile new file mode 100644 index 0000000000..a55a56fe80 --- /dev/null +++ b/R/pkg/src/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.so string_hash_code.c + +clean: + rm -f *.o + rm -f *.so + +.PHONY: all clean diff --git a/R/pkg/src/Makefile.win b/R/pkg/src/Makefile.win new file mode 100644 index 0000000000..aa486d8228 --- /dev/null +++ b/R/pkg/src/Makefile.win @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +all: sharelib + +sharelib: string_hash_code.c + R CMD SHLIB -o SparkR.dll string_hash_code.c + +clean: + rm -f *.o + rm -f *.dll + +.PHONY: all clean diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src/string_hash_code.c new file mode 100644 index 0000000000..e3274b9a0c --- /dev/null +++ b/R/pkg/src/string_hash_code.c @@ -0,0 +1,49 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* + * A C function for R extension which implements the Java String hash algorithm. + * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function + * + */ + +#include +#include + +/* for compatibility with R before 3.1 */ +#ifndef IS_SCALAR +#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1) +#endif + +SEXP stringHashCode(SEXP string) { + const char* str; + R_xlen_t len, i; + int hashCode = 0; + + if (!IS_SCALAR(string, STRSXP)) { + error("invalid input"); + } + + str = CHAR(asChar(string)); + len = XLENGTH(asChar(string)); + + for (i = 0; i < len; i++) { + hashCode = (hashCode << 5) - hashCode + *str++; + } + + return ScalarInteger(hashCode); +} diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R new file mode 100644 index 0000000000..4f8a1ed2d8 --- /dev/null +++ b/R/pkg/tests/run-all.R @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) +library(SparkR) + +test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh new file mode 100755 index 0000000000..e82ad0ba2c --- /dev/null +++ b/R/run-tests.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FWDIR="$(cd `dirname $0`; pwd)" + +FAILED=0 +LOGFILE=$FWDIR/unit-tests.out +rm -f $LOGFILE + +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +if [[ $FAILED != 0 ]]; then + cat $LOGFILE + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi diff --git a/bin/sparkR b/bin/sparkR new file mode 100755 index 0000000000..8c918e2b09 --- /dev/null +++ b/bin/sparkR @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +source "$SPARK_HOME"/bin/load-spark-env.sh + +function usage() { + if [ -n "$1" ]; then + echo $1 + fi + echo "Usage: ./bin/sparkR [options]" 1>&2 + "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit $2 +} +export -f usage + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage +fi + +exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd new file mode 100644 index 0000000000..d7b60183ca --- /dev/null +++ b/bin/sparkR.cmd @@ -0,0 +1,23 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This is the entry point for running SparkR. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0sparkR2.cmd %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd new file mode 100644 index 0000000000..e47f22c730 --- /dev/null +++ b/bin/sparkR2.cmd @@ -0,0 +1,26 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem Figure out where the Spark framework is installed +set SPARK_HOME=%~dp0.. + +call %SPARK_HOME%\bin\load-spark-env.cmd + + +call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* diff --git a/core/pom.xml b/core/pom.xml index 6cd1965ec3..e80829b7a7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -442,4 +442,55 @@ + + + Windows + + + Windows + + + + \ + .bat + + + + unix + + + unix + + + + / + .sh + + + + sparkr + + + + org.codehaus.mojo + exec-maven-plugin + 1.3.2 + + + sparkr-pkg + compile + + exec + + + + + ..${path.separator}R${path.separator}install-dev${script.extension} + + + + + + + diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala new file mode 100644 index 0000000000..3a2c94bd9d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetSocketAddress, ServerSocket} +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} + +import org.apache.spark.Logging + +/** + * Netty-based backend server that is used to communicate between R and Java. + */ +private[spark] class RBackend { + + private[this] var channelFuture: ChannelFuture = null + private[this] var bootstrap: ServerBootstrap = null + private[this] var bossGroup: EventLoopGroup = null + + def init(): Int = { + bossGroup = new NioEventLoopGroup(2) + val workerGroup = bossGroup + val handler = new RBackendHandler(this) + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast("frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", handler) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } + +} + +private[spark] object RBackend extends Logging { + def main(args: Array[String]): Unit = { + if (args.length < 1) { + System.err.println("Usage: RBackend ") + System.exit(-1) + } + val sparkRBackend = new RBackend() + try { + // bind to random port + val boundPort = sparkRBackend.init() + val serverSocket = new ServerSocket(0, 1) + val listenPort = serverSocket.getLocalPort() + + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.writeInt(listenPort) + dos.close() + f.renameTo(new File(path)) + + // wait for the end of stdin, then exit + new Thread("wait for socket to close") { + setDaemon(true) + override def run(): Unit = { + // any un-catched exception will also shutdown JVM + val buf = new Array[Byte](1024) + // shutdown JVM if R does not connect back in 10 seconds + serverSocket.setSoTimeout(10000) + try { + val inSocket = serverSocket.accept() + serverSocket.close() + // wait for the end of socket, closed if R process die + inSocket.getInputStream().read(buf) + } finally { + sparkRBackend.close() + System.exit(0) + } + } + }.start() + + sparkRBackend.run() + } catch { + case e: IOException => + logError("Server shutting down: failed with exception ", e) + sparkRBackend.close() + System.exit(1) + } + System.exit(0) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala new file mode 100644 index 0000000000..0075d96371 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.HashMap + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.api.r.SerDe._ + +/** + * Handler for RBackend + * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use + * this across connections ? + */ +@Sharable +private[r] class RBackendHandler(server: RBackend) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = readBoolean(dis) + val objId = readString(dis) + val methodName = readString(dis) + val numArgs = readInt(dis) + + if (objId == "SparkRHandler") { + methodName match { + case "stopBackend" => + writeInt(dos, 0) + writeType(dos, "void") + server.close() + case "rm" => + try { + val t = readObjectType(dis) + assert(t == 'c') + val objToRemove = readString(dis) + JVMObjectTracker.remove(objToRemove) + writeInt(dos, 0) + writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + writeInt(dos, -1) + } + case _ => dos.writeInt(-1) + } + } else { + handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + } + + val reply = bos.toByteArray + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Close the connection when an exception is raised. + cause.printStackTrace() + ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + try { + val cls = if (isStatic) { + Class.forName(objId) + } else { + JVMObjectTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + val args = readArgs(numArgs, dis) + + val methods = cls.getMethods + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val methods = selectedMethods.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + } + if (methods.isEmpty) { + logWarning(s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + val ret = methods.head.invoke(obj, args:_*) + + // Write status bit + writeInt(dos, 0) + writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args:_*) + + writeInt(dos, 0) + writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Exception => + logError(s"$methodName on $objId failed", e) + writeInt(dos, -1) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. + def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 to numArgs - 1) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + return false + } + } + true + } +} + +/** + * Helper singleton that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] object JVMObjectTracker { + + // TODO: This map should be thread-safe if we want to support multiple + // connections at the same time + private[this] val objMap = new HashMap[String, Object] + + // TODO: We support only one connection now, so an integer is fine. + // Investigate using use atomic integer in the future. + private[this] var objCounter: Int = 0 + + def getObject(id: String): Object = { + objMap(id) + } + + def get(id: String): Option[Object] = { + objMap.get(id) + } + + def put(obj: Object): String = { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + + def remove(id: String): Option[Object] = { + objMap.remove(id) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala new file mode 100644 index 0000000000..5fa4d483b8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.ServerSocket +import java.util.{Map => JMap} + +import scala.collection.JavaConversions._ +import scala.io.Source +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( + parent: RDD[T], + numPartitions: Int, + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Broadcast[Object]]) + extends RDD[U](parent) with Logging { + override def getPartitions: Array[Partition] = parent.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + + // The parent may be also an RRDD, so we should launch it first. + val parentIterator = firstParent[T].iterator(partition, context) + + // we expect two connections + val serverSocket = new ServerSocket(0, 2) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(rLibDir, listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + val dataStream = openDataStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? + val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + printOut.println(elem) + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def openDataStream(input: InputStream): Closeable + + protected def read(): U +} + +/** + * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R. + * This is used by SparkR's shuffle operations. + */ +private class PairwiseRRDD[T: ClassTag]( + parent: RDD[T], + numPartitions: Int, + hashFunc: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, (Int, Array[Byte])]( + parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): (Int, Array[Byte]) = { + try { + val length = dataStream.readInt() + + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null // End of input + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +/** + * An RDD that stores serialized R objects as Array[Byte]. + */ +private class RRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, Array[Byte]]( + parent, -1, func, deserializer, serializer, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: DataInputStream = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new DataInputStream(input) + dataStream + } + + override protected def read(): Array[Byte] = { + try { + val length = dataStream.readInt() + + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj, 0, length) + obj + case _ => null + } + } catch { + case eof: EOFException => { + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } + } + + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * An RDD that stores R objects as Array[String]. + */ +private class StringRRDD[T: ClassTag]( + parent: RDD[T], + func: Array[Byte], + deserializer: String, + packageNames: Array[Byte], + rLibDir: String, + broadcastVars: Array[Object]) + extends BaseRRDD[T, String]( + parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + private var dataStream: BufferedReader = _ + + override protected def openDataStream(input: InputStream): Closeable = { + dataStream = new BufferedReader(new InputStreamReader(input)) + dataStream + } + + override protected def read(): String = { + try { + dataStream.readLine() + } catch { + case e: IOException => { + throw new SparkException("R worker exited unexpectedly (crashed)", e) + } + } + } + + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + def createSparkContext( + master: String, + appName: String, + sparkHome: String, + jars: Array[String], + sparkEnvirMap: JMap[Object, Object], + sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + + val sparkConf = new SparkConf().setAppName(appName) + .setSparkHome(sparkHome) + .setJars(jars) + + // Override `master` if we have a user-specified value + if (master != "") { + sparkConf.setMaster(master) + } else { + // If conf has no master set it to "local" to maintain + // backwards compatibility + sparkConf.setIfMissing("spark.master", "local") + } + + for ((name, value) <- sparkEnvirMap) { + sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + } + for ((name, value) <- sparkExecutorEnvMap) { + sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + } + + new JavaSparkContext(sparkConf) + } + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { + val rCommand = "Rscript" + val rOptions = "--vanilla" + val rExecScript = rLibDir + "/SparkR/worker/" + script + val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(rLibDir, daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(rLibDir, port, "worker.R") + } + } + + /** + * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is + * called from R. + */ + def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala new file mode 100644 index 0000000000..ccb2a371f4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Date, Time} + +import scala.collection.JavaConversions._ + +/** + * Utility functions to serialize, deserialize objects to / from R + */ +private[spark] object SerDe { + + // Type mapping from R to Java + // + // NULL -> void + // integer -> Int + // character -> String + // logical -> Boolean + // double, numeric -> Double + // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time + // + // list[T] -> Array[T], where T is one of above mentioned types + // environment -> Map[String, T], where T is a native type + // jobj -> Object, where jobj is an object created in the backend + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + def readTypedObject( + dis: DataInputStream, + dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => JVMObjectTracker.getObject(readString(dis)) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + val bytesRead = in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + val asciiBytes = new Array[Byte](len) + in.readFully(asciiBytes) + assert(asciiBytes(len - 1) == 0) + val str = new String(asciiBytes.dropRight(1).map(_.toChar)) + str + } + + def readBoolean(in: DataInputStream): Boolean = { + val intVal = in.readInt() + if (intVal == 0) false else true + } + + def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + def readTime(in: DataInputStream): Time = { + val t = in.readDouble() + new Time((t * 1000L).toLong) + } + + def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesType = readObjectType(in) + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + mapAsJavaMap(keys.zip(values).toMap) + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Methods to write out data from Java to R + // + // Type mapping from Java to R + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Double -> double + // Long -> double + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "long" | "java.lang.Long" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeDouble(out: DataOutputStream, value: Double): Unit = { + out.writeDouble(value) + } + + def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + val intValue = if (value) 1 else 0 + out.writeInt(intValue) + } + + def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + + // NOTE: Only works for ASCII right now + def writeString(out: DataOutputStream, value: String): Unit = { + val len = value.length + out.writeInt(len + 1) // For the \0 + out.writeBytes(value) + out.writeByte(0) + } + + def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = JVMObjectTracker.put(value) + writeString(out, objId) + } + + def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private[r] object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala new file mode 100644 index 0000000000..e99779f299 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io._ +import java.util.concurrent.{Semaphore, TimeUnit} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path + +import org.apache.spark.api.r.RBackend +import org.apache.spark.util.RedirectThread + +/** + * Main class used to launch SparkR applications using spark-submit. It executes R as a + * subprocess and then has it connect back to the JVM to access system properties etc. + */ +object RRunner { + def main(args: Array[String]): Unit = { + val rFile = PythonRunner.formatPath(args(0)) + + val otherArgs = args.slice(1, args.length) + + // Time to wait for SparkR backend to initialize in seconds + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val rCommand = "Rscript" + + // Check if the file path exists. + // If not, change directory to current working directory for YARN cluster mode + val rF = new File(rFile) + val rFileNormalized = if (!rF.exists()) { + new Path(rFile).getName + } else { + rFile + } + + // Launch a SparkR backend server for the R process to connect to; this will let it see our + // Java system properties etc. + val sparkRBackend = new RBackend() + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { + override def run() { + sparkRBackendPort = sparkRBackend.init() + initialized.release() + sparkRBackend.run() + } + } + + sparkRBackendThread.start() + // Wait for RBackend initialization to finish + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + // Launch R + val returnCode = try { + val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val env = builder.environment() + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + val sparkHome = System.getenv("SPARK_HOME") + env.put("R_PROFILE_USER", + Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator)) + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + val process = builder.start() + + new RedirectThread(process.getInputStream, System.out, "redirect R output").start() + + process.waitFor() + } finally { + sparkRBackend.close() + } + System.exit(returnCode) + } else { + System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.exit(-1) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 660307d19e..60bc243ebf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -77,6 +77,7 @@ object SparkSubmit { // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -284,6 +285,13 @@ object SparkSubmit { } } + // Require all R files to be local + if (args.isR && !isYarnCluster) { + if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { + printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + } + } + // The following modes are not supported or applicable (clusterManager, deployMode) match { case (MESOS, CLUSTER) => @@ -291,6 +299,9 @@ object SparkSubmit { case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") + case (STANDALONE, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + + "applications on standalone clusters.") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => @@ -317,11 +328,32 @@ object SparkSubmit { } } - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython && isYarnCluster) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) + // If we're running a R app, set the main class to our specific R runner + if (args.isR && deployMode == CLIENT) { + if (args.primaryResource == SPARKR_SHELL) { + args.mainClass = "org.apache.spark.api.r.RBackend" + } else { + // If a R file is provided, add it to the child arguments and list of files to deploy. + // Usage: RRunner
[app arguments] + args.mainClass = "org.apache.spark.deploy.RRunner" + args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + + if (isYarnCluster) { + // In yarn-cluster mode for a python app, add primary resource and pyFiles to files + // that can be distributed with the job + if (args.isPython) { + args.files = mergeFileLists(args.files, args.primaryResource) + args.files = mergeFileLists(args.files, args.pyFiles) + } + + // In yarn-cluster mode for a R app, add primary resource to files + // that can be distributed with the job + if (args.isR) { + args.files = mergeFileLists(args.files, args.primaryResource) + } } // Special flag to avoid deprecation warnings at the client @@ -405,8 +437,8 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" - // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !args.isPython) { + // For python and R files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !args.isPython && !args.isR) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) @@ -447,6 +479,10 @@ object SparkSubmit { childArgs += ("--py-files", pyFilesNames) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") + } else if (args.isR) { + val mainFile = new Path(args.primaryResource).getName + childArgs += ("--primary-r-file", mainFile) + childArgs += ("--class", "org.apache.spark.deploy.RRunner") } else { if (args.primaryResource != SPARK_INTERNAL) { childArgs += ("--jar", args.primaryResource) @@ -591,15 +627,15 @@ object SparkSubmit { /** * Return whether the given primary resource represents a user jar. */ - private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) + private[deploy] def isUserJar(res: String): Boolean = { + !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res) } /** * Return whether the given primary resource represents a shell. */ - private[deploy] def isShell(primaryResource: String): Boolean = { - primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL + private[deploy] def isShell(res: String): Boolean = { + (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL) } /** @@ -619,12 +655,19 @@ object SparkSubmit { /** * Return whether the given primary resource requires running python. */ - private[deploy] def isPython(primaryResource: String): Boolean = { - primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL + private[deploy] def isPython(res: String): Boolean = { + res != null && res.endsWith(".py") || res == PYSPARK_SHELL } - private[deploy] def isInternal(primaryResource: String): Boolean = { - primaryResource == SPARK_INTERNAL + /** + * Return whether the given primary resource requires running R. + */ + private[deploy] def isR(res: String): Boolean = { + res != null && res.endsWith(".R") || res == SPARKR_SHELL + } + + private[deploy] def isInternal(res: String): Boolean = { + res == SPARK_INTERNAL } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 6eb73c4347..03ecf3fd99 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var isR: Boolean = false var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() var proxyUser: String = null @@ -158,7 +159,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .getOrElse(sparkProperties.get("spark.executor.instances").orNull) // Try to set main class from JAR if no --class argument is given - if (mainClass == null && !isPython && primaryResource != null) { + if (mainClass == null && !isPython && !isR && primaryResource != null) { val uri = new URI(primaryResource) val uriScheme = uri.getScheme() @@ -211,9 +212,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)") + SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") } - if (mainClass == null && !isPython) { + if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") } if (pyFiles != null && !isPython) { @@ -414,6 +415,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S opt } isPython = SparkSubmit.isPython(opt) + isR = SparkSubmit.isR(opt) false } diff --git a/dev/run-tests b/dev/run-tests index 561d7fc9e7..1b6cf78b5d 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -236,3 +236,18 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS ./python/run-tests + +echo "" +echo "=========================================================================" +echo "Running SparkR tests" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS + +if [ $(command -v R) ]; then + ./R/install-dev.sh + ./R/run-tests.sh +else + echo "Ignoring SparkR tests as R was not found in PATH" +fi + diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 8ab6db6925..154e01255b 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -25,3 +25,4 @@ readonly BLOCK_BUILD=14 readonly BLOCK_MIMA=15 readonly BLOCK_SPARK_UNIT_TESTS=16 readonly BLOCK_PYSPARK_UNIT_TESTS=17 +readonly BLOCK_SPARKR_UNIT_TESTS=18 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f10aa6b59e..f6372835a6 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -210,6 +210,8 @@ done failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then + failing_test="SparkR unit tests" else failing_test="some tests" fi diff --git a/docs/README.md b/docs/README.md index 3773ea25c8..5852f972a0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -58,13 +58,19 @@ phase, use the following sytax: We use Sphinx to generate Python API docs, so you will need to install it by running `sudo pip install sphinx`. -## API Docs (Scaladoc and Sphinx) +## knitr, devtools + +SparkR documentation is written using `roxygen2` and we use `knitr`, `devtools` to generate +documentation. To install these packages you can run `install.packages(c("knitr", "devtools"))` from a +R console. + +## API Docs (Scaladoc, Sphinx, roxygen2) You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory. Similarly, you can build just the PySpark docs by running `make html` from the SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. +public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a @@ -72,5 +78,5 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs [Sphinx](http://sphinx-doc.org/). -NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 +NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 2e88b30936..b92c75f90b 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -84,6 +84,7 @@
  • Scala
  • Java
  • Python
  • +
  • R
  • diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3c626a0b7f..0ea3f8eab4 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -78,5 +78,18 @@ if not (ENV['SKIP_API'] == '1' or ENV['SKIP_SCALADOC'] == '1') puts "cp -r python/docs/_build/html/. docs/api/python" cp_r("python/docs/_build/html/.", "docs/api/python") - cd("..") + # Build SparkR API docs + puts "Moving to R directory and building roxygen docs." + cd("R") + puts `./create-docs.sh` + + puts "Moving back into home dir." + cd("../") + + puts "Making directory api/R" + mkdir_p "docs/api/R" + + puts "cp -r R/pkg/html/. docs/api/R" + cp_r("R/pkg/html/.", "docs/api/R") + end diff --git a/examples/src/main/r/kmeans.R b/examples/src/main/r/kmeans.R new file mode 100644 index 0000000000..6e6b5cb937 --- /dev/null +++ b/examples/src/main/r/kmeans.R @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +# Logistic regression in Spark. +# Note: unlike the example in Scala, a point here is represented as a vector of +# doubles. + +parseVectors <- function(lines) { + lines <- strsplit(as.character(lines) , " ", fixed = TRUE) + list(matrix(as.numeric(unlist(lines)), ncol = length(lines[[1]]))) +} + +dist.fun <- function(P, C) { + apply( + C, + 1, + function(x) { + colSums((t(P) - x)^2) + } + ) +} + +closestPoint <- function(P, C) { + max.col(-dist.fun(P, C)) +} +# Main program + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: kmeans ") + q("no") +} + +sc <- sparkR.init(appName = "RKMeans") +K <- as.integer(args[[2]]) +convergeDist <- as.double(args[[3]]) + +lines <- textFile(sc, args[[1]]) +points <- cache(lapplyPartition(lines, parseVectors)) +# kPoints <- take(points, K) +kPoints <- do.call(rbind, takeSample(points, FALSE, K, 16189L)) +tempDist <- 1.0 + +while (tempDist > convergeDist) { + closest <- lapplyPartition( + lapply(points, + function(p) { + cp <- closestPoint(p, kPoints); + mapply(list, unique(cp), split.data.frame(cbind(1, p), cp), SIMPLIFY=FALSE) + }), + function(x) {do.call(c, x) + }) + + pointStats <- reduceByKey(closest, + function(p1, p2) { + t(colSums(rbind(p1, p2))) + }, + 2L) + + newPoints <- do.call( + rbind, + collect(lapply(pointStats, + function(tup) { + point.sum <- tup[[2]][, -1] + point.count <- tup[[2]][, 1] + point.sum/point.count + }))) + + D <- dist.fun(kPoints, newPoints) + tempDist <- sum(D[cbind(1:3, max.col(-D))]) + kPoints <- newPoints + cat("Finished iteration (delta = ", tempDist, ")\n") +} + +cat("Final centers:\n") +writeLines(unlist(lapply(kPoints, paste, collapse = " "))) diff --git a/examples/src/main/r/linear_solver_mnist.R b/examples/src/main/r/linear_solver_mnist.R new file mode 100644 index 0000000000..c864a4232d --- /dev/null +++ b/examples/src/main/r/linear_solver_mnist.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Instructions: https://github.com/amplab-extras/SparkR-pkg/wiki/SparkR-Example:-Digit-Recognition-on-EC2 + +library(SparkR) +library(Matrix) + +args <- commandArgs(trailing = TRUE) + +# number of random features; default to 1100 +D <- ifelse(length(args) > 0, as.integer(args[[1]]), 1100) +# number of partitions for training dataset +trainParts <- 12 +# dimension of digits +d <- 784 +# number of test examples +NTrain <- 60000 +# number of training examples +NTest <- 10000 +# scale of features +gamma <- 4e-4 + +sc <- sparkR.init(appName = "SparkR-LinearSolver") + +# You can also use HDFS path to speed things up: +# hdfs:///train-mnist-dense-with-labels.data +file <- textFile(sc, "/data/train-mnist-dense-with-labels.data", trainParts) + +W <- gamma * matrix(nrow=D, ncol=d, data=rnorm(D*d)) +b <- 2 * pi * matrix(nrow=D, ncol=1, data=runif(D)) +broadcastW <- broadcast(sc, W) +broadcastB <- broadcast(sc, b) + +includePackage(sc, Matrix) +numericLines <- lapplyPartitionsWithIndex(file, + function(split, part) { + matList <- sapply(part, function(line) { + as.numeric(strsplit(line, ",", fixed=TRUE)[[1]]) + }, simplify=FALSE) + mat <- Matrix(ncol=d+1, data=unlist(matList, F, F), + sparse=T, byrow=T) + mat + }) + +featureLabels <- cache(lapplyPartition( + numericLines, + function(part) { + label <- part[,1] + mat <- part[,-1] + ones <- rep(1, nrow(mat)) + features <- cos( + mat %*% t(value(broadcastW)) + (matrix(ncol=1, data=ones) %*% t(value(broadcastB)))) + onesMat <- Matrix(ones) + featuresPlus <- cBind(features, onesMat) + labels <- matrix(nrow=nrow(mat), ncol=10, data=-1) + for (i in 1:nrow(mat)) { + labels[i, label[i]] <- 1 + } + list(label=labels, features=featuresPlus) + })) + +FTF <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$features + }), flatten=F)) + +FTY <- Reduce("+", collect(lapplyPartition(featureLabels, + function(part) { + t(part$features) %*% part$label + }), flatten=F)) + +# solve for the coefficient matrix +C <- solve(FTF, FTY) + +test <- Matrix(as.matrix(read.csv("/data/test-mnist-dense-with-labels.data", + header=F), sparse=T)) +testData <- test[,-1] +testLabels <- matrix(ncol=1, test[,1]) + +err <- 0 + +# contstruct the feature maps for all examples from this digit +featuresTest <- cos(testData %*% t(value(broadcastW)) + + (matrix(ncol=1, data=rep(1, NTest)) %*% t(value(broadcastB)))) +featuresTest <- cBind(featuresTest, Matrix(rep(1, NTest))) + +# extract the one vs. all assignment +results <- featuresTest %*% C +labelsGot <- apply(results, 1, which.max) +err <- sum(testLabels != labelsGot) / nrow(testLabels) + +cat("\nFinished running. The error rate is: ", err, ".\n") diff --git a/examples/src/main/r/logistic_regression.R b/examples/src/main/r/logistic_regression.R new file mode 100644 index 0000000000..2a86aa9816 --- /dev/null +++ b/examples/src/main/r/logistic_regression.R @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 3) { + print("Usage: logistic_regression ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "LogisticRegressionR") +iterations <- as.integer(args[[2]]) +D <- as.integer(args[[3]]) + +readPartition <- function(part){ + part = strsplit(part, " ", fixed = T) + list(matrix(as.numeric(unlist(part)), ncol = length(part[[1]]))) +} + +# Read data points and convert each partition to a matrix +points <- cache(lapplyPartition(textFile(sc, args[[1]]), readPartition)) + +# Initialize w to a random value +w <- runif(n=D, min = -1, max = 1) +cat("Initial w: ", w, "\n") + +# Compute logistic regression gradient for a matrix of data points +gradient <- function(partition) { + partition = partition[[1]] + Y <- partition[, 1] # point labels (first column of input file) + X <- partition[, -1] # point coordinates + + # For each point (x, y), compute gradient function + dot <- X %*% w + logit <- 1 / (1 + exp(-Y * dot)) + grad <- t(X) %*% ((logit - 1) * Y) + list(grad) +} + +for (i in 1:iterations) { + cat("On iteration ", i, "\n") + w <- w - reduce(lapplyPartition(points, gradient), "+") +} + +cat("Final w: ", w, "\n") diff --git a/examples/src/main/r/pi.R b/examples/src/main/r/pi.R new file mode 100644 index 0000000000..aa7a833e14 --- /dev/null +++ b/examples/src/main/r/pi.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +sc <- sparkR.init(appName = "PiR") + +slices <- ifelse(length(args) > 1, as.integer(args[[2]]), 2) + +n <- 100000 * slices + +piFunc <- function(elem) { + rands <- runif(n = 2, min = -1, max = 1) + val <- ifelse((rands[1]^2 + rands[2]^2) < 1, 1.0, 0.0) + val +} + + +piFuncVec <- function(elems) { + message(length(elems)) + rands1 <- runif(n = length(elems), min = -1, max = 1) + rands2 <- runif(n = length(elems), min = -1, max = 1) + val <- ifelse((rands1^2 + rands2^2) < 1, 1.0, 0.0) + sum(val) +} + +rdd <- parallelize(sc, 1:n, slices) +count <- reduce(lapplyPartition(rdd, piFuncVec), sum) +cat("Pi is roughly", 4.0 * count / n, "\n") +cat("Num elements in RDD ", count(rdd), "\n") diff --git a/examples/src/main/r/wordcount.R b/examples/src/main/r/wordcount.R new file mode 100644 index 0000000000..b734cb0ecf --- /dev/null +++ b/examples/src/main/r/wordcount.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: wordcount ") + q("no") +} + +# Initialize Spark context +sc <- sparkR.init(appName = "RwordCount") +lines <- textFile(sc, args[[1]]) + +words <- flatMap(lines, + function(line) { + strsplit(line, " ")[[1]] + }) +wordCount <- lapply(words, function(word) { list(word, 1L) }) + +counts <- reduceByKey(wordCount, "+", 2L) +output <- collect(counts) + +for (wordcount in output) { + cat(wordcount[[1]], ": ", wordcount[[2]], "\n") +} diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 9b04732afe..f4ebc25bdd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -274,14 +274,14 @@ class CommandBuilderUtils { } /** - * Quotes a string so that it can be used in a command string and be parsed back into a single - * argument by python's "shlex.split()" function. - * + * Quotes a string so that it can be used in a command string. * Basically, just add simple escapes. E.g.: * original single argument : ab "cd" ef * after: "ab \"cd\" ef" + * + * This can be parsed back into a single argument by python's "shlex.split()" function. */ - static String quoteForPython(String s) { + static String quoteForCommandString(String s) { StringBuilder quoted = new StringBuilder().append('"'); for (int i = 0; i < s.length(); i++) { int cp = s.codePointAt(i); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 91dcf70f10..a73c9c87e3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -17,14 +17,9 @@ package org.apache.spark.launcher; +import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; +import java.util.*; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -53,6 +48,20 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell"; + /** + * Name of the app resource used to identify the SparkR shell. The command line parser expects + * the resource name to be the very first argument to spark-submit in this case. + * + * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit + * (see sparkR.R), and can cause this code to enter into an infinite loop. + */ + static final String SPARKR_SHELL = "sparkr-shell-main"; + + /** + * This is the actual resource name that identifies the SparkR shell to SparkSubmit. + */ + static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + /** * This map must match the class names for available special classes, since this modifies the way * command line parsing works. This maps the class name to the resource to use when calling @@ -87,6 +96,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = true; appResource = PYSPARK_SHELL_RESOURCE; submitArgs = args.subList(1, args.size()); + } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL_RESOURCE; + submitArgs = args.subList(1, args.size()); } else { this.allowsMixedArguments = false; } @@ -98,6 +111,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { public List buildCommand(Map env) throws IOException { if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { return buildPySparkShellCommand(env); + } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); } @@ -213,26 +228,14 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { return buildCommand(env); } - // When launching the pyspark shell, the spark-submit arguments should be stored in the - // PYSPARK_SUBMIT_ARGS env variable. The executable is the PYSPARK_DRIVER_PYTHON env variable - // set by the pyspark script, followed by PYSPARK_DRIVER_PYTHON_OPTS. checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); - Properties props = loadPropertiesFile(); - mergeEnvPathList(env, getLibPathEnvName(), - firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); - - // Store spark-submit arguments in an environment variable, since there's no way to pass - // them to shell.py on the comand line. - StringBuilder submitArgs = new StringBuilder(); - for (String arg : buildSparkSubmitArgs()) { - if (submitArgs.length() > 0) { - submitArgs.append(" "); - } - submitArgs.append(quoteForPython(arg)); - } - env.put("PYSPARK_SUBMIT_ARGS", submitArgs.toString()); + // When launching the pyspark shell, the spark-submit arguments should be stored in the + // PYSPARK_SUBMIT_ARGS env variable. + constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); + // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, + // followed by PYSPARK_DRIVER_PYTHON_OPTS. List pyargs = new ArrayList(); pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); @@ -243,6 +246,44 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { return pyargs; } + private List buildSparkRCommand(Map env) throws IOException { + if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { + appResource = appArgs.get(0); + appArgs.remove(0); + return buildCommand(env); + } + // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS + // env variable. + constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); + + // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. + String sparkHome = System.getenv("SPARK_HOME"); + env.put("R_PROFILE_USER", + join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); + + List args = new ArrayList(); + args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); + return args; + } + + private void constructEnvVarArgs( + Map env, + String submitArgsEnvVariable) throws IOException { + Properties props = loadPropertiesFile(); + mergeEnvPathList(env, getLibPathEnvName(), + firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props)); + + StringBuilder submitArgs = new StringBuilder(); + for (String arg : buildSparkSubmitArgs()) { + if (submitArgs.length() > 0) { + submitArgs.append(" "); + } + submitArgs.append(quoteForCommandString(arg)); + } + env.put(submitArgsEnvVariable, submitArgs.toString()); + } + + private boolean isClientMode(Properties userProps) { String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER)); // Default master is "local[*]", so assume client mode in that case. diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index dba0203867..1ae42eed8a 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -79,9 +79,9 @@ public class CommandBuilderUtilsSuite { @Test public void testPythonArgQuoting() { - assertEquals("\"abc\"", quoteForPython("abc")); - assertEquals("\"a b c\"", quoteForPython("a b c")); - assertEquals("\"a \\\"b\\\" c\"", quoteForPython("a \"b\" c")); + assertEquals("\"abc\"", quoteForCommandString("abc")); + assertEquals("\"a b c\"", quoteForCommandString("a b c")); + assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); } private void testOpt(String opts, List expected) { diff --git a/pom.xml b/pom.xml index 42bd926a2f..70e297c4f0 100644 --- a/pom.xml +++ b/pom.xml @@ -1749,5 +1749,8 @@ parquet-provided + + sparkr + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a5e6b638d2..53ad67372e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.NumericType @Experimental class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { - private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala new file mode 100644 index 0000000000..d1ea7cc3e9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.r + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.r.SerDe +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} + +private[r] object SQLUtils { + def createSQLContext(jsc: JavaSparkContext): SQLContext = { + new SQLContext(jsc) + } + + def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { + new JavaSparkContext(sqlCtx.sparkContext) + } + + def toSeq[T](arr: Array[T]): Seq[T] = { + arr.toSeq + } + + def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val num = schema.fields.size + val rowRDD = rdd.map(bytesToRow) + sqlContext.createDataFrame(rowRDD, schema) + } + + // A helper to include grouping columns in Agg() + def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = { + val aggExprs = exprs.map { col => + col.expr match { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.simpleString)() + } + } + gd.toDF(aggExprs) + } + + def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { + df.map(r => rowToRBytes(r)) + } + + private[this] def bytesToRow(bytes: Array[Byte]): Row = { + val bis = new ByteArrayInputStream(bytes) + val dis = new DataInputStream(bis) + val num = SerDe.readInt(dis) + Row.fromSeq((0 until num).map { i => + SerDe.readObject(dis) + }.toSeq) + } + + private[this] def rowToRBytes(row: Row): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, row.length) + (0 until row.length).map { idx => + val obj: Object = row(idx).asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def dfToCols(df: DataFrame): Array[Array[Byte]] = { + // localDF is Array[Row] + val localDF = df.collect() + val numCols = df.columns.length + // dfCols is Array[Array[Any]] + val dfCols = convertRowsToColumns(localDF, numCols) + + dfCols.map { col => + colToRBytes(col) + } + } + + def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + (0 until numCols).map { colIdx => + localDF.map { row => + row(colIdx) + } + }.toArray + } + + def colToRBytes(col: Array[Any]): Array[Byte] = { + val numRows = col.length + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + SerDe.writeInt(dos, numRows) + + col.map { item => + val obj: Object = item.asInstanceOf[Object] + SerDe.writeObject(dos, obj) + } + bos.toByteArray() + } + + def saveMode(mode: String): SaveMode = { + mode match { + case "append" => SaveMode.Append + case "overwrite" => SaveMode.Overwrite + case "error" => SaveMode.ErrorIfExists + case "ignore" => SaveMode.Ignore + } + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 24a1e02795..32bc4e5663 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -469,6 +469,9 @@ private[spark] class ApplicationMaster( System.setProperty("spark.submit.pyFiles", PythonRunner.formatPaths(args.pyFiles).mkString(",")) } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + // TODO(davies): add R dependencies here + } val mainMethod = userClassLoader.loadClass(args.userClass) .getMethod("main", classOf[Array[String]]) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index e1a992af3a..ae6dc1094d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -25,6 +25,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var primaryPyFile: String = null + var primaryRFile: String = null var pyFiles: String = null var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 @@ -54,6 +55,10 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--py-files") :: value :: tail => pyFiles = value args = tail @@ -79,6 +84,11 @@ class ApplicationMasterArguments(val args: Array[String]) { } } + if (primaryPyFile != null && primaryRFile != null) { + System.err.println("Cannot have primary-py-file and primary-r-file at the same time") + System.exit(-1) + } + userArgs = userArgsBuffer.readOnly } @@ -92,6 +102,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file + | --primary-r-file A main R file | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7219852c0a..c1effd3c8a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -491,6 +491,12 @@ private[spark] class Client( } else { Nil } + val primaryRFile = + if (args.primaryRFile != null) { + Seq("--primary-r-file", args.primaryRFile) + } else { + Nil + } val amClass = if (isClusterMode) { Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName @@ -500,12 +506,15 @@ private[spark] class Client( if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs } + if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { + args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs + } val userArgs = args.userArgs.flatMap { arg => Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ userArgs ++ - Seq( + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 3bc7eb1abf..da6798cb1b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -32,6 +32,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var userClass: String = null var pyFiles: String = null var primaryPyFile: String = null + var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() var executorMemory = 1024 // MB var executorCores = 1 @@ -150,6 +151,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) primaryPyFile = value args = tail + case ("--primary-r-file") :: value :: tail => + primaryRFile = value + args = tail + case ("--args" | "--arg") :: value :: tail => if (args(0) == "--args") { println("--args is deprecated. Use --arg instead.") @@ -228,6 +233,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) throw new IllegalArgumentException(getUsageMessage(args)) } } + + if (primaryPyFile != null && primaryRFile != null) { + throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + + " at the same time") + } } private def getUsageMessage(unknownParam: List[String] = null): String = { @@ -240,6 +250,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | mode) | --class CLASS_NAME Name of your application's main class (required) | --primary-py-file A main Python file + | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2)