[SPARK-27102][R][PYTHON][CORE] Remove the references to Python's Scala codes in R's Scala codes

## What changes were proposed in this pull request?

Currently, R's Scala code refers to Python's Scala code for the sake of deduplication, which is a bit odd. For instance, when an exception happens on the R side, it shows a Python-related code path, which makes debugging confusing. There should rather be one shared code base that both R and Python use.

This PR proposes:

1. Add a `SocketAuthServer` and move `PythonServer` into it so that `PythonRDD` and `RRDD` can share it (see the usage sketch below).
2. Move `readRDDFromFile` and `readRDDFromInputStream` into `JavaRDD`.
3. Reuse the existing `RAuthHelper` and remove the duplicated `RSocketAuthHelper` from `RRDD`.
4. Rename `getEncryptionEnabled` to `isEncryptionEnabled` while I am here.

With these changes, the code under the following paths:

- `sql/core/src/main/scala/org/apache/spark/sql/api/r`
- `core/src/main/scala/org/apache/spark/api/r`
- `mllib/src/main/scala/org/apache/spark/ml/r`

no longer refers to Python's Scala code.
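
To illustrate how the pieces fit together, here is a minimal sketch (not code from this commit; the class name, package placement, and `numSlices` usage are hypothetical, and these classes are `private[spark]`, so real callers live inside Spark itself) of the kind of server both the Python and R sides can now build on top of the shared `SocketAuthServer`, mirroring `PythonParallelizeServer` and `RParallelizeServer` in the diff below:

```scala
package org.apache.spark.api.r  // hypothetical placement; SocketAuthServer is private[spark]

import java.net.Socket

import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.security.SocketAuthServer

// A one-connection server: the external process connects, authenticates with
// the secret, streams length-prefixed byte records, and the parsed result
// becomes a JavaRDD[Array[Byte]].
class ExampleParallelizeServer(sc: SparkContext, parallelism: Int)
  extends SocketAuthServer[JavaRDD[Array[Byte]]]("example-parallelize-server") {

  override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
    JavaRDD.readRDDFromInputStream(sc, sock.getInputStream(), parallelism)
  }
}

// Driver-side usage sketch: hand (port, secret) to the external process,
// then block until the connection has been handled.
//   val server = new ExampleParallelizeServer(sc, numSlices)
//   // send server.port and server.secret to the R / Python process ...
//   val rdd = server.getResult()
```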

## How was this patch tested?

Existing tests should cover this.

Closes #24023 from HyukjinKwon/SPARK-27102.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Hyukjin Kwon 2019-03-10 15:08:23 +09:00
parent 470313e660
commit 28d003097b
11 changed files with 188 additions and 151 deletions

@@ -175,7 +175,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
if (objectSize < sizeLimit) {
jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
} else {
if (callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)) {
if (callJStatic("org.apache.spark.api.r.RUtils", "isEncryptionEnabled", sc)) {
connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
# the length of slices here is the parallelism to use in the jvm's sc.parallelize()
parallelism <- as.integer(numSlices)

@@ -138,7 +138,7 @@ test_that("createDataFrame large objects", {
enableHiveSupport = FALSE))
sc <- getSparkContext()
actual <- callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)
actual <- callJStatic("org.apache.spark.api.r.RUtils", "isEncryptionEnabled", sc)
expected <- as.logical(encryptionEnabled)
expect_equal(actual, expected)

@@ -17,6 +17,9 @@
package org.apache.spark.api.java
import java.io.{DataInputStream, EOFException, FileInputStream, InputStream}
import scala.collection.mutable
import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -213,4 +216,34 @@ object JavaRDD {
implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
private[api] def readRDDFromFile(
sc: JavaSparkContext,
filename: String,
parallelism: Int): JavaRDD[Array[Byte]] = {
readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
}
private[api] def readRDDFromInputStream(
sc: SparkContext,
in: InputStream,
parallelism: Int): JavaRDD[Array[Byte]] = {
val din = new DataInputStream(in)
try {
val objs = new mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
val length = din.readInt()
val obj = new Array[Byte](length)
din.readFully(obj)
objs += obj
}
} catch {
case eof: EOFException => // No-op
}
JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
} finally {
din.close()
}
}
}
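
For reference, the stream that `readRDDFromInputStream` consumes is simply a sequence of length-prefixed byte records terminated by end-of-stream. A minimal writer sketch (a hypothetical helper, not part of this commit) that produces a compatible payload:

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

object RecordStreamSketch {
  // Hypothetical helper for illustration: emits records in the format that
  // readRDDFromInputStream parses, i.e. readInt for the length followed by
  // readFully for the payload, repeated until the reader hits EOF.
  def writeRecords(records: Seq[Array[Byte]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    records.foreach { rec =>
      dos.writeInt(rec.length)
      dos.write(rec)
    }
    dos.close()
    bos.toByteArray
  }
}
```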

@@ -42,7 +42,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.util._
@@ -171,32 +171,18 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
}
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
def readRDDFromFile(
sc: JavaSparkContext,
filename: String,
parallelism: Int): JavaRDD[Array[Byte]] = {
JavaRDD.readRDDFromFile(sc, filename, parallelism)
}
def readRDDFromInputStream(
sc: SparkContext,
in: InputStream,
parallelism: Int): JavaRDD[Array[Byte]] = {
val din = new DataInputStream(in)
try {
val objs = new mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
val length = din.readInt()
val obj = new Array[Byte](length)
din.readFully(obj)
objs += obj
}
} catch {
case eof: EOFException => // No-op
}
JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
} finally {
din.close()
}
JavaRDD.readRDDFromInputStream(sc, in, parallelism)
}
def setupBroadcast(path: String): PythonBroadcast = {
@@ -430,21 +416,7 @@ private[spark] object PythonRDD extends Logging {
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
serveToStream(threadName, authHelper)(writeFunc)
}
private[spark] def serveToStream(
threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
: Array[Any] = {
val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
val out = new BufferedOutputStream(s.getOutputStream())
Utils.tryWithSafeFinally {
writeFunc(out)
} {
out.close()
}
}
Array(port, secret)
SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -666,8 +638,8 @@ private[spark] class PythonAccumulatorV2(
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
with Logging {
private var encryptionServer: PythonServer[Unit] = null
private var decryptionServer: PythonServer[Unit] = null
private var encryptionServer: SocketAuthServer[Unit] = null
private var decryptionServer: SocketAuthServer[Unit] = null
/**
* Read data from disks, then copy it to `out`
@@ -712,7 +684,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
def setupEncryptionServer(): Array[Any] = {
encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
encryptionServer = new SocketAuthServer[Unit]("broadcast-encrypt-server") {
override def handleConnection(sock: Socket): Unit = {
val env = SparkEnv.get
val in = sock.getInputStream()
@@ -725,7 +697,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
def setupDecryptionServer(): Array[Any] = {
decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") {
decryptionServer = new SocketAuthServer[Unit]("broadcast-decrypt-server-for-driver") {
override def handleConnection(sock: Socket): Unit = {
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
Utils.tryWithSafeFinally {
@@ -820,91 +792,13 @@ private[spark] object DechunkedInputStream {
}
}
/**
* Creates a server in the jvm to communicate with python for handling one batch of data, with
* authentication and error handling.
*/
private[spark] abstract class PythonServer[T](
authHelper: SocketAuthHelper,
threadName: String) {
def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
def this(threadName: String) = this(SparkEnv.get, threadName)
val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
promise.complete(Try(handleConnection(sock)))
}
/**
* Handle a connection which has already been authenticated. Any error from this function
* will clean up this connection and the entire server, and get propogated to [[getResult]].
*/
def handleConnection(sock: Socket): T
val promise = Promise[T]()
/**
* Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
* handleConnection throws an exception, this will throw an exception which includes the original
* exception as a cause.
*/
def getResult(): T = {
getResult(Duration.Inf)
}
def getResult(wait: Duration): T = {
ThreadUtils.awaitResult(promise.future, wait)
}
}
private[spark] object PythonServer {
/**
* Create a socket server and run user function on the socket in a background thread.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
*
* The thread will terminate after the supplied user function, or if there are any exceptions.
*
* If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
*
* @return The port number of a local socket and the secret for authentication.
*/
def setupOneConnectionServer(
authHelper: SocketAuthHelper,
threadName: String)
(func: Socket => Unit): (Int, String) = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
sock = serverSocket.accept()
authHelper.authClient(sock)
func(sock)
} finally {
JavaUtils.closeQuietly(serverSocket)
JavaUtils.closeQuietly(sock)
}
}
}.start()
(serverSocket.getLocalPort, authHelper.secret)
}
}
/**
* Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol.
*/
private[spark] class EncryptedPythonBroadcastServer(
val env: SparkEnv,
val idsAndFiles: Seq[(Long, String)])
extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
extends SocketAuthServer[Unit]("broadcast-decrypt-server") with Logging {
override def handleConnection(socket: Socket): Unit = {
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
@@ -942,7 +836,7 @@ private[spark] class EncryptedPythonBroadcastServer(
* over a socket. This is used in preference to writing data to a file when encryption is enabled.
*/
private[spark] abstract class PythonRDDServer
extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
extends SocketAuthServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
val in = sock.getInputStream()
@@ -961,4 +855,3 @@ private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
PythonRDD.readRDDFromInputStream(sc, input, parallelism)
}
}

@@ -76,7 +76,7 @@ private[spark] object PythonUtils {
jm.asScala.toMap
}
def getEncryptionEnabled(sc: JavaSparkContext): Boolean = {
def isEncryptionEnabled(sc: JavaSparkContext): Boolean = {
sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED)
}
}

@@ -17,9 +17,8 @@
package org.apache.spark.api.r
import java.io.{DataInputStream, File, OutputStream}
import java.io.{File, OutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Map => JMap}
import scala.collection.JavaConverters._
@@ -27,11 +26,10 @@ import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.{PythonRDD, PythonServer}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
parent: RDD[T],
@@ -163,12 +161,12 @@ private[spark] object RRDD {
*/
def createRDDFromFile(jsc: JavaSparkContext, fileName: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
JavaRDD.readRDDFromFile(jsc, fileName, parallelism)
}
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
}
}
@@ -177,23 +175,11 @@ private[spark] object RRDD {
* over a socket. This is used in preference to writing data to a file when encryption is enabled.
*/
private[spark] class RParallelizeServer(sc: JavaSparkContext, parallelism: Int)
extends PythonServer[JavaRDD[Array[Byte]]](
new RSocketAuthHelper(), "sparkr-parallelize-server") {
extends SocketAuthServer[JavaRDD[Array[Byte]]](
new RAuthHelper(SparkEnv.get.conf), "sparkr-parallelize-server") {
override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
val in = sock.getInputStream()
PythonRDD.readRDDFromInputStream(sc.sc, in, parallelism)
}
}
private[spark] class RSocketAuthHelper extends SocketAuthHelper(SparkEnv.get.conf) {
override protected def readUtf8(s: Socket): String = {
val din = new DataInputStream(s.getInputStream())
val len = din.readInt()
val bytes = new Array[Byte](len)
din.readFully(bytes)
// The R code adds a null terminator to serialized strings, so ignore it here.
assert(bytes(bytes.length - 1) == 0) // sanity check.
new String(bytes, 0, bytes.length - 1, UTF_8)
JavaRDD.readRDDFromInputStream(sc.sc, in, parallelism)
}
}

@@ -22,7 +22,6 @@ import java.util.Arrays
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.internal.config._
private[spark] object RUtils {
@@ -108,5 +107,7 @@ private[spark] object RUtils {
}
}
def getEncryptionEnabled(sc: JavaSparkContext): Boolean = PythonUtils.getEncryptionEnabled(sc)
def isEncryptionEnabled(sc: JavaSparkContext): Boolean = {
sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED)
}
}

@@ -17,7 +17,7 @@
package org.apache.spark.security
import java.io.{DataInputStream, DataOutputStream, InputStream}
import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8
@@ -115,3 +115,19 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
}
}
private[spark] object SocketAuthHelper {
def serveToStream(
threadName: String,
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
val out = new BufferedOutputStream(s.getOutputStream())
Utils.tryWithSafeFinally {
writeFunc(out)
} {
out.close()
}
}
Array(port, secret)
}
}
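
As a usage sketch (not from this commit; it assumes a caller inside Spark, since these helpers are `private[spark]`), the new `serveToStream` starts a one-connection server and returns `Array(port, secret)`, while the supplied function writes the data once a client has connected and authenticated:

```scala
package org.apache.spark.security  // hypothetical placement; SocketAuthHelper is private[spark]

import java.nio.charset.StandardCharsets.UTF_8

import org.apache.spark.SparkEnv

object ServeToStreamSketch {
  // Hypothetical driver-side helper: serve one payload to one client.
  def serveHello(): Array[Any] = {
    val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
    // The returned Array(port, secret) is handed to the external (Python or R)
    // process so it can connect, authenticate, and read the bytes.
    SocketAuthHelper.serveToStream("serve-example", authHelper) { out =>
      // Runs on a background thread after the client connects and authenticates.
      out.write("hello from the JVM".getBytes(UTF_8))
    }
  }
}
```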

@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.security
import java.net.{InetAddress, ServerSocket, Socket}
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
import scala.language.existentials
import scala.util.Try
import org.apache.spark.SparkEnv
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.ThreadUtils
/**
* Creates a server in the JVM to communicate with external processes (e.g., Python and R) for
* handling one batch of data, with authentication and error handling.
*/
private[spark] abstract class SocketAuthServer[T](
authHelper: SocketAuthHelper,
threadName: String) {
def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
def this(threadName: String) = this(SparkEnv.get, threadName)
private val promise = Promise[T]()
val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
promise.complete(Try(handleConnection(sock)))
}
/**
* Handle a connection which has already been authenticated. Any error from this function
* will clean up this connection and the entire server, and get propagated to [[getResult]].
*/
def handleConnection(sock: Socket): T
/**
* Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
* handleConnection throws an exception, this will throw an exception which includes the original
* exception as a cause.
*/
def getResult(): T = {
getResult(Duration.Inf)
}
def getResult(wait: Duration): T = {
ThreadUtils.awaitResult(promise.future, wait)
}
}
private[spark] object SocketAuthServer {
/**
* Create a socket server and run user function on the socket in a background thread.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
*
* The thread will terminate after the supplied user function, or if there are any exceptions.
*
* If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
*
* @return The port number of a local socket and the secret for authentication.
*/
def setupOneConnectionServer(
authHelper: SocketAuthHelper,
threadName: String)
(func: Socket => Unit): (Int, String) = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
sock = serverSocket.accept()
authHelper.authClient(sock)
func(sock)
} finally {
JavaUtils.closeQuietly(serverSocket)
JavaUtils.closeQuietly(sock)
}
}
}.start()
(serverSocket.getLocalPort, authHelper.secret)
}
}

@@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets
import scala.concurrent.duration.Duration
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
class PythonRDDSuite extends SparkFunSuite {
@@ -59,7 +59,7 @@ class PythonRDDSuite extends SparkFunSuite {
}
class ExceptionPythonServer(authHelper: SocketAuthHelper)
extends PythonServer[Unit](authHelper, "error-server") {
extends SocketAuthServer[Unit](authHelper, "error-server") {
override def handleConnection(sock: Socket): Unit = {
throw new Exception("exception within handleConnection")

@@ -204,7 +204,7 @@ class SparkContext(object):
# If encryption is enabled, we need to setup a server in the jvm to read broadcast
# data via a socket.
# scala's mangled names w/ $ in them require special treatment.
self._encryption_enabled = self._jvm.PythonUtils.getEncryptionEnabled(self._jsc)
self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
self.pythonVer = "%d.%d" % sys.version_info[:2]