diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8b5a7a9aef..5ed5070558 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -660,6 +660,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial with Logging { private var encryptionServer: PythonServer[Unit] = null + private var decryptionServer: PythonServer[Unit] = null /** * Read data from disks, then copy it to `out` @@ -708,16 +709,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial override def handleConnection(sock: Socket): Unit = { val env = SparkEnv.get val in = sock.getInputStream() - val dir = new File(Utils.getLocalDir(env.conf)) - val file = File.createTempFile("broadcast", "", dir) - path = file.getAbsolutePath - val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path)) + val abspath = new File(path).getAbsolutePath + val out = env.serializerManager.wrapForEncryption(new FileOutputStream(abspath)) DechunkedInputStream.dechunkAndCopyToOutput(in, out) } } Array(encryptionServer.port, encryptionServer.secret) } + def setupDecryptionServer(): Array[Any] = { + decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") { + override def handleConnection(sock: Socket): Unit = { + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream())) + Utils.tryWithSafeFinally { + val in = SparkEnv.get.serializerManager.wrapForEncryption(new FileInputStream(path)) + Utils.tryWithSafeFinally { + Utils.copyStream(in, out, false) + } { + in.close() + } + out.flush() + } { + JavaUtils.closeQuietly(out) + } + } + } + Array(decryptionServer.port, decryptionServer.secret) + } + + def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult() + def waitTillDataReceived(): Unit = encryptionServer.getResult() } // scalastyle:on no.finalize diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 1c7f2a7418..29358b5740 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -77,11 +77,12 @@ class Broadcast(object): # we're on the driver. We want the pickled data to end up in a file (maybe encrypted) f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) self._path = f.name - python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) + self._sc = sc + self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path) if sc._encryption_enabled: # with encryption, we ask the jvm to do the encryption for us, we send it data # over a socket - port, auth_secret = python_broadcast.setupEncryptionServer() + port, auth_secret = self._python_broadcast.setupEncryptionServer() (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret) broadcast_out = ChunkedStream(encryption_sock_file, 8192) else: @@ -89,12 +90,14 @@ class Broadcast(object): broadcast_out = f self.dump(value, broadcast_out) if sc._encryption_enabled: - python_broadcast.waitTillDataReceived() - self._jbroadcast = sc._jsc.broadcast(python_broadcast) + self._python_broadcast.waitTillDataReceived() + self._jbroadcast = sc._jsc.broadcast(self._python_broadcast) self._pickle_registry = pickle_registry else: # we're on an executor self._jbroadcast = None + self._sc = None + self._python_broadcast = None if sock_file is not None: # the jvm is doing decryption for us. Read the value # immediately from the sock_file @@ -134,7 +137,15 @@ class Broadcast(object): """ Return the broadcasted value """ if not hasattr(self, "_value") and self._path is not None: - self._value = self.load_from_path(self._path) + # we only need to decrypt it here when encryption is enabled and + # if its on the driver, since executor decryption is handled already + if self._sc is not None and self._sc._encryption_enabled: + port, auth_secret = self._python_broadcast.setupDecryptionServer() + (decrypted_sock_file, _) = local_connect_and_auth(port, auth_secret) + self._python_broadcast.waitTillBroadcastDataSent() + return self.load(decrypted_sock_file) + else: + self._value = self.load_from_path(self._path) return self._value def unpersist(self, blocking=False): diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py index a98626e8f4..11d31d24bb 100644 --- a/python/pyspark/tests/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -67,6 +67,21 @@ class BroadcastTest(unittest.TestCase): def test_broadcast_no_encryption(self): self._test_multiple_broadcasts() + def _test_broadcast_on_driver(self, *extra_confs): + conf = SparkConf() + for key, value in extra_confs: + conf.set(key, value) + conf.setMaster("local-cluster[2,1,1024]") + self.sc = SparkContext(conf=conf) + bs = self.sc.broadcast(value=5) + self.assertEqual(5, bs.value) + + def test_broadcast_value_driver_no_encryption(self): + self._test_broadcast_on_driver() + + def test_broadcast_value_driver_encryption(self): + self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true")) + class BroadcastFrameProtocolTest(unittest.TestCase):