diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 1c064a63ef..619153645d 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -175,7 +175,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
   if (objectSize < sizeLimit) {
     jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
   } else {
-    if (callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)) {
+    if (callJStatic("org.apache.spark.api.r.RUtils", "isEncryptionEnabled", sc)) {
       connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
       # the length of slices here is the parallelism to use in the jvm's sc.parallelize()
       parallelism <- as.integer(numSlices)
diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R
index 1525bdb2f5..e01f6ee005 100644
--- a/R/pkg/tests/fulltests/test_Serde.R
+++ b/R/pkg/tests/fulltests/test_Serde.R
@@ -138,7 +138,7 @@ test_that("createDataFrame large objects", {
                                     enableHiveSupport = FALSE))
 
     sc <- getSparkContext()
-    actual <- callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)
+    actual <- callJStatic("org.apache.spark.api.r.RUtils", "isEncryptionEnabled", sc)
     expected <- as.logical(encryptionEnabled)
     expect_equal(actual, expected)
 
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 41b5cab601..6f0182255e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.api.java
 
+import java.io.{DataInputStream, EOFException, FileInputStream, InputStream}
+
+import scala.collection.mutable
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
@@ -213,4 +216,34 @@ object JavaRDD {
   implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
 
   implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
+
+  private[api] def readRDDFromFile(
+      sc: JavaSparkContext,
+      filename: String,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+  }
+
+  private[api] def readRDDFromInputStream(
+      sc: SparkContext,
+      in: InputStream,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    val din = new DataInputStream(in)
+    try {
+      val objs = new mutable.ArrayBuffer[Array[Byte]]
+      try {
+        while (true) {
+          val length = din.readInt()
+          val obj = new Array[Byte](length)
+          din.readFully(obj)
+          objs += obj
+        }
+      } catch {
+        case eof: EOFException => // No-op
+      }
+      JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
+    } finally {
+      din.close()
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0937a63dad..5b492b1f39 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -42,7 +42,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
 import org.apache.spark.util._
 
 
@@ -171,32 +171,18 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
   }
 
-  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
-  JavaRDD[Array[Byte]] = {
-    readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+  def readRDDFromFile(
+      sc: JavaSparkContext,
+      filename: String,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    JavaRDD.readRDDFromFile(sc, filename, parallelism)
   }
 
   def readRDDFromInputStream(
       sc: SparkContext,
       in: InputStream,
       parallelism: Int): JavaRDD[Array[Byte]] = {
-    val din = new DataInputStream(in)
-    try {
-      val objs = new mutable.ArrayBuffer[Array[Byte]]
-      try {
-        while (true) {
-          val length = din.readInt()
-          val obj = new Array[Byte](length)
-          din.readFully(obj)
-          objs += obj
-        }
-      } catch {
-        case eof: EOFException => // No-op
-      }
-      JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
-    } finally {
-      din.close()
-    }
+    JavaRDD.readRDDFromInputStream(sc, in, parallelism)
   }
 
   def setupBroadcast(path: String): PythonBroadcast = {
@@ -430,21 +416,7 @@ private[spark] object PythonRDD extends Logging {
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    serveToStream(threadName, authHelper)(writeFunc)
-  }
-
-  private[spark] def serveToStream(
-      threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
-    : Array[Any] = {
-    val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
-      val out = new BufferedOutputStream(s.getOutputStream())
-      Utils.tryWithSafeFinally {
-        writeFunc(out)
-      } {
-        out.close()
-      }
-    }
-    Array(port, secret)
+    SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -666,8 +638,8 @@ private[spark] class PythonAccumulatorV2(
 private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
   with Logging {
 
-  private var encryptionServer: PythonServer[Unit] = null
-  private var decryptionServer: PythonServer[Unit] = null
+  private var encryptionServer: SocketAuthServer[Unit] = null
+  private var decryptionServer: SocketAuthServer[Unit] = null
 
   /**
    * Read data from disks, then copy it to `out`
@@ -712,7 +684,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
   }
 
   def setupEncryptionServer(): Array[Any] = {
-    encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+    encryptionServer = new SocketAuthServer[Unit]("broadcast-encrypt-server") {
       override def handleConnection(sock: Socket): Unit = {
         val env = SparkEnv.get
         val in = sock.getInputStream()
@@ -725,7 +697,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
   }
 
   def setupDecryptionServer(): Array[Any] = {
-    decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") {
+    decryptionServer = new SocketAuthServer[Unit]("broadcast-decrypt-server-for-driver") {
       override def handleConnection(sock: Socket): Unit = {
         val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
         Utils.tryWithSafeFinally {
@@ -820,91 +792,13 @@ private[spark] object DechunkedInputStream {
   }
 }
 
-/**
- * Creates a server in the jvm to communicate with python for handling one batch of data, with
- * authentication and error handling.
- */
-private[spark] abstract class PythonServer[T](
-    authHelper: SocketAuthHelper,
-    threadName: String) {
-
-  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
-  def this(threadName: String) = this(SparkEnv.get, threadName)
-
-  val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
-    promise.complete(Try(handleConnection(sock)))
-  }
-
-  /**
-   * Handle a connection which has already been authenticated. Any error from this function
-   * will clean up this connection and the entire server, and get propogated to [[getResult]].
-   */
-  def handleConnection(sock: Socket): T
-
-  val promise = Promise[T]()
-
-  /**
-   * Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
-   * handleConnection throws an exception, this will throw an exception which includes the original
-   * exception as a cause.
-   */
-  def getResult(): T = {
-    getResult(Duration.Inf)
-  }
-
-  def getResult(wait: Duration): T = {
-    ThreadUtils.awaitResult(promise.future, wait)
-  }
-
-}
-
-private[spark] object PythonServer {
-
-  /**
-   * Create a socket server and run user function on the socket in a background thread.
-   *
-   * The socket server can only accept one connection, or close if no connection
-   * in 15 seconds.
-   *
-   * The thread will terminate after the supplied user function, or if there are any exceptions.
-   *
-   * If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
-   *
-   * @return The port number of a local socket and the secret for authentication.
-   */
-  def setupOneConnectionServer(
-      authHelper: SocketAuthHelper,
-      threadName: String)
-      (func: Socket => Unit): (Int, String) = {
-    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run(): Unit = {
-        var sock: Socket = null
-        try {
-          sock = serverSocket.accept()
-          authHelper.authClient(sock)
-          func(sock)
-        } finally {
-          JavaUtils.closeQuietly(serverSocket)
-          JavaUtils.closeQuietly(sock)
-        }
-      }
-    }.start()
-    (serverSocket.getLocalPort, authHelper.secret)
-  }
-}
-
 /**
  * Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol.
  */
 private[spark] class EncryptedPythonBroadcastServer(
     val env: SparkEnv,
     val idsAndFiles: Seq[(Long, String)])
-  extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
+  extends SocketAuthServer[Unit]("broadcast-decrypt-server") with Logging {
 
   override def handleConnection(socket: Socket): Unit = {
     val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
@@ -942,7 +836,7 @@ private[spark] class EncryptedPythonBroadcastServer(
  * over a socket. This is used in preference to writing data to a file when encryption is enabled.
  */
 private[spark] abstract class PythonRDDServer
-    extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
+    extends SocketAuthServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
 
   def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
     val in = sock.getInputStream()
@@ -961,4 +855,3 @@ private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
     PythonRDD.readRDDFromInputStream(sc, input, parallelism)
   }
 }
-
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index b6b0cac910..ab1bf69e27 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -76,7 +76,7 @@ private[spark] object PythonUtils {
     jm.asScala.toMap
   }
 
-  def getEncryptionEnabled(sc: JavaSparkContext): Boolean = {
+  def isEncryptionEnabled(sc: JavaSparkContext): Boolean = {
     sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 04fc6e18c1..4a59c3e209 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,9 +17,8 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataInputStream, File, OutputStream}
+import java.io.{File, OutputStream}
 import java.net.Socket
-import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
@@ -27,11 +26,10 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
-import org.apache.spark.api.python.{PythonRDD, PythonServer}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -163,12 +161,12 @@
    */
   def createRDDFromFile(jsc: JavaSparkContext, fileName: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
-    PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
+    JavaRDD.readRDDFromFile(jsc, fileName, parallelism)
   }
 
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
+    SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
   }
 }
 
@@ -177,23 +175,11 @@
  * over a socket. This is used in preference to writing data to a file when encryption is enabled.
  */
 private[spark] class RParallelizeServer(sc: JavaSparkContext, parallelism: Int)
-    extends PythonServer[JavaRDD[Array[Byte]]](
-      new RSocketAuthHelper(), "sparkr-parallelize-server") {
+    extends SocketAuthServer[JavaRDD[Array[Byte]]](
+      new RAuthHelper(SparkEnv.get.conf), "sparkr-parallelize-server") {
 
   override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
     val in = sock.getInputStream()
-    PythonRDD.readRDDFromInputStream(sc.sc, in, parallelism)
-  }
-}
-
-private[spark] class RSocketAuthHelper extends SocketAuthHelper(SparkEnv.get.conf) {
-  override protected def readUtf8(s: Socket): String = {
-    val din = new DataInputStream(s.getInputStream())
-    val len = din.readInt()
-    val bytes = new Array[Byte](len)
-    din.readFully(bytes)
-    // The R code adds a null terminator to serialized strings, so ignore it here.
-    assert(bytes(bytes.length - 1) == 0) // sanity check.
-    new String(bytes, 0, bytes.length - 1, UTF_8)
+    JavaRDD.readRDDFromInputStream(sc.sc, in, parallelism)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
index 6832223a5d..5a433022ac 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -22,7 +22,6 @@ import java.util.Arrays
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.internal.config._
 
 private[spark] object RUtils {
@@ -108,5 +107,7 @@ private[spark] object RUtils {
     }
   }
 
-  def getEncryptionEnabled(sc: JavaSparkContext): Boolean = PythonUtils.getEncryptionEnabled(sc)
+  def isEncryptionEnabled(sc: JavaSparkContext): Boolean = {
+    sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index ea38ccb289..3a107c0764 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.security
 
-import java.io.{DataInputStream, DataOutputStream, InputStream}
+import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8
 
@@ -115,3 +115,19 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
   }
 
 }
+
+private[spark] object SocketAuthHelper {
+  def serveToStream(
+      threadName: String,
+      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
+    val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
+      val out = new BufferedOutputStream(s.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
+      }
+    }
+    Array(port, secret)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
new file mode 100644
index 0000000000..c65c8fd6e3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import java.net.{InetAddress, ServerSocket, Socket}
+
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
+import scala.language.existentials
+import scala.util.Try
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.ThreadUtils
+
+
+/**
+ * Creates a server in the JVM to communicate with external processes (e.g., Python and R) for
+ * handling one batch of data, with authentication and error handling.
+ */
+private[spark] abstract class SocketAuthServer[T](
+    authHelper: SocketAuthHelper,
+    threadName: String) {
+
+  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
+  def this(threadName: String) = this(SparkEnv.get, threadName)
+
+  private val promise = Promise[T]()
+
+  val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
+    promise.complete(Try(handleConnection(sock)))
+  }
+
+  /**
+   * Handle a connection which has already been authenticated. Any error from this function
+   * will clean up this connection and the entire server, and get propagated to [[getResult]].
+   */
+  def handleConnection(sock: Socket): T
+
+  /**
+   * Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
+   * handleConnection throws an exception, this will throw an exception which includes the original
+   * exception as a cause.
+   */
+  def getResult(): T = {
+    getResult(Duration.Inf)
+  }
+
+  def getResult(wait: Duration): T = {
+    ThreadUtils.awaitResult(promise.future, wait)
+  }
+
+}
+
+private[spark] object SocketAuthServer {
+
+  /**
+   * Create a socket server and run user function on the socket in a background thread.
+   *
+   * The socket server can only accept one connection, or close if no connection
+   * in 15 seconds.
+   *
+   * The thread will terminate after the supplied user function, or if there are any exceptions.
+   *
+   * If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
+   *
+   * @return The port number of a local socket and the secret for authentication.
+   */
+  def setupOneConnectionServer(
+      authHelper: SocketAuthHelper,
+      threadName: String)
+      (func: Socket => Unit): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          func(sock)
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 6f9b583898..e2ec50fb1f 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
 
 class PythonRDDSuite extends SparkFunSuite {
 
@@ -59,7 +59,7 @@ class PythonRDDSuite extends SparkFunSuite {
   }
 
   class ExceptionPythonServer(authHelper: SocketAuthHelper)
-      extends PythonServer[Unit](authHelper, "error-server") {
+      extends SocketAuthServer[Unit](authHelper, "error-server") {
 
     override def handleConnection(sock: Socket): Unit = {
       throw new Exception("exception within handleConnection")
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 5a4bd574c7..63c3043a94 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -204,7 +204,7 @@ class SparkContext(object):
         # If encryption is enabled, we need to setup a server in the jvm to read broadcast
         # data via a socket.
         # scala's mangled names w/ $ in them require special treatment.
-        self._encryption_enabled = self._jvm.PythonUtils.getEncryptionEnabled(self._jsc)
+        self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]
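
Reviewer note (not part of the patch): the moved SocketAuthServer is used exactly as PythonServer was before: a subclass implements handleConnection(), the constructor immediately binds a localhost-only, single-connection ServerSocket and exposes (port, secret) to hand to the external process, and getResult() blocks for the handler's outcome. The sketch below is a minimal, hypothetical illustration under those assumptions; it requires a live SparkEnv (for the single-argument constructor) and placement somewhere under the org.apache.spark package, and the ReadIntServer name and thread name are invented for the example only.

package org.apache.spark.security

import java.io.DataInputStream
import java.net.Socket

// Hypothetical subclass: read a single Int from one authenticated connection.
// Only SocketAuthServer, handleConnection, port, secret and getResult come from
// the patch; everything else here is illustrative.
private[spark] class ReadIntServer extends SocketAuthServer[Int]("read-int-server") {
  override def handleConnection(sock: Socket): Int = {
    // authHelper.authClient(sock) has already succeeded before this is invoked.
    new DataInputStream(sock.getInputStream()).readInt()
  }
}

// Driver-side usage, mirroring PythonParallelizeServer / RParallelizeServer:
//   val server = new ReadIntServer()
//   // hand (server.port, server.secret) to the external process, which connects,
//   // authenticates with the secret, and writes an Int
//   val value = server.getResult()  // blocks until handleConnection completes or fails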