[SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC

## What changes were proposed in this pull request?

Currently, PythonBroadcast may delete its data file while a Python worker still needs it. This happens because PythonBroadcast overrides the `finalize()` method to delete its data file, so when GC runs and there are no more references to the broadcast variable, `finalize()` may be triggered and delete the data file prematurely. It also means the data behind a Python broadcast variable cannot be deleted when `unpersist()`/`destroy()` is called; deletion relies entirely on GC.
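
For illustration, here is a minimal PySpark sketch of the failure mode, condensed from the regression test added in this PR (it assumes an existing `SparkContext` named `sc`; the actual test additionally uses `local[1,1]` and a tiny `spark.memory.fraction` to make the GC-triggered deletion easier to reproduce):

```python
import time

b = sc.broadcast([100])

# Run a job that does not read b.value yet.
sc.parallelize([0], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()

# Force a JVM GC on the driver. With the old finalize()-based cleanup, the
# PythonBroadcast's data file could be deleted here even though `b` is still
# referenced on the Python side.
sc._jvm.java.lang.System.gc()
time.sleep(5)

# This job actually reads b.value on the worker; before this fix it could fail
# because the underlying data file had already been removed by GC.
sc.parallelize([1], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
```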

In this PR, we remove the `finalize()` method and instead map the PythonBroadcast data file to a BroadcastBlock (which carries the same broadcast id as the broadcast variable that wraps the PythonBroadcast) while the PythonBroadcast is being deserialized. As a result, the data file is deleted together with the other pieces of the broadcast variable when `unpersist()`/`destroy()` is called, and no longer relies on GC.
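
The user-visible effect is that explicit cleanup from PySpark now also removes the file backing the PythonBroadcast. A minimal sketch (assuming an existing `SparkContext` named `sc`):

```python
b = sc.broadcast([100])
sc.parallelize([1], 1).map(lambda x: b.value[0]).collect()

# The data file is now registered as a BroadcastBlock, so removing the
# broadcast also removes the file, without waiting for GC to run finalize().
b.unpersist(blocking=True)

# destroy() removes all data and metadata of the broadcast; the variable
# cannot be used again afterwards.
b.destroy()
```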

## How was this patch tested?

Added a Python test, and tested manually (verified creation/deletion of the broadcast block).
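
For reference, a rough sketch of what such a manual check can look like from PySpark. It goes through py4j into internal APIs (`SparkEnv`, `BlockManager`, `DiskBlockManager`, `BroadcastBlockId`), so treat the exact accessors as assumptions that may differ across Spark versions; also note the block only exists in a JVM where the PythonBroadcast has actually been deserialized, e.g. on an executor of a real cluster:

```python
b = sc.broadcast([100])
sc.parallelize([1], 1).map(lambda x: b.value[0]).collect()

jvm = sc._jvm
# BroadcastBlockId(broadcastId, "python") is the block id the data file is mapped to.
block_id = jvm.org.apache.spark.storage.BroadcastBlockId(b._jbroadcast.id(), "python")
disk_bm = jvm.org.apache.spark.SparkEnv.get().blockManager().diskBlockManager()

print(disk_bm.containsBlock(block_id))  # True in the JVM where the block was created
b.destroy()
print(disk_bm.containsBlock(block_id))  # False once the broadcast is destroyed
```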

Closes #25262 from Ngone51/SPARK-28486.

Authored-by: wuyi <ngone_5451@163.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
wuyi 2019-08-05 20:18:53 +09:00 committed by HyukjinKwon
parent cae500a255
commit 94499af6f0
5 changed files with 67 additions and 28 deletions


@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag

 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
@@ -39,6 +40,7 @@ import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util._
@@ -697,10 +699,11 @@ private[spark] class PythonAccumulatorV2(
   }
 }

-// scalastyle:off no.finalize
 private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
   with Logging {

+  // id of the Broadcast variable which wrapped this PythonBroadcast
+  private var broadcastId: Long = _
   private var encryptionServer: SocketAuthServer[Unit] = null
   private var decryptionServer: SocketAuthServer[Unit] = null
@@ -708,6 +711,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
    * Read data from disks, then copy it to `out`
    */
   private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+    out.writeLong(broadcastId)
     val in = new FileInputStream(new File(path))
     try {
       Utils.copyStream(in, out)
@@ -717,33 +721,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
   }

   /**
-   * Write data into disk, using randomly generated name.
+   * Write data into disk and map it to a broadcast block.
    */
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
-    val file = File.createTempFile("broadcast", "", dir)
-    path = file.getAbsolutePath
-    val out = new FileOutputStream(file)
-    Utils.tryWithSafeFinally {
-      Utils.copyStream(in, out)
-    } {
-      out.close()
-    }
-  }
-
-  /**
-   * Delete the file once the object is GCed.
-   */
-  override def finalize() {
-    if (!path.isEmpty) {
-      val file = new File(path)
-      if (file.exists()) {
-        if (!file.delete()) {
-          logWarning(s"Error deleting ${file.getPath}")
-        }
-      }
-    }
-    super.finalize()
+  private def readObject(in: ObjectInputStream): Unit = {
+    broadcastId = in.readLong()
+    val blockId = BroadcastBlockId(broadcastId, "python")
+    val blockManager = SparkEnv.get.blockManager
+    val diskBlockManager = blockManager.diskBlockManager
+    if (!diskBlockManager.containsBlock(blockId)) {
+      Utils.tryOrIOException {
+        val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
+        val file = File.createTempFile("broadcast", "", dir)
+        val out = new FileOutputStream(file)
+        Utils.tryWithSafeFinally {
+          val size = Utils.copyStream(in, out)
+          val ct = implicitly[ClassTag[Object]]
+          // SPARK-28486: map broadcast file to a broadcast block, so that it could be
+          // cleared by unpersist/destroy rather than gc(previously).
+          val blockStoreUpdater = blockManager.
+            TempFileBasedBlockStoreUpdater(blockId, StorageLevel.DISK_ONLY, ct, file, size)
+          blockStoreUpdater.save()
+        } {
+          out.close()
+        }
+      }
+    }
+    path = diskBlockManager.getFile(blockId).getAbsolutePath
+  }
+
+  def setBroadcastId(bid: Long): Unit = {
+    this.broadcastId = bid
   }

   def setupEncryptionServer(): Array[Any] = {
@@ -783,7 +790,6 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
   def waitTillDataReceived(): Unit = encryptionServer.getResult()
 }
-// scalastyle:on no.finalize

 /**
  * The inverse of pyspark's ChunkedStream for sending data of unknown size.


@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}

 import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.api.python.PythonBroadcast
 import org.apache.spark.internal.Logging

 private[spark] class BroadcastManager(
@@ -59,7 +60,18 @@
   }

   def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+    val bid = nextBroadcastId.getAndIncrement()
+    value_ match {
+      case pb: PythonBroadcast =>
+        // SPARK-28486: attach this new broadcast variable's id to the PythonBroadcast,
+        // so that underlying data file of PythonBroadcast could be mapped to the
+        // BroadcastBlockId according to this id. Please see the specific usage of the
+        // id in PythonBroadcast.readObject().
+        pb.setBroadcastId(bid)
+      case _ => // do nothing
+    }
+    broadcastFactory.newBroadcast[T](value_, isLocal, bid)
   }

   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {


@@ -211,7 +211,7 @@ private[spark] class BlockManager(
    *
    * @param blockSize the decrypted size of the block
    */
-  private abstract class BlockStoreUpdater[T](
+  private[spark] abstract class BlockStoreUpdater[T](
       blockSize: Long,
       blockId: BlockId,
       level: StorageLevel,
@@ -357,7 +357,7 @@ private[spark] class BlockManager(
   /**
    * Helper for storing a block based from bytes already in a local temp file.
    */
-  private case class TempFileBasedBlockStoreUpdater[T](
+  private[spark] case class TempFileBasedBlockStoreUpdater[T](
       blockId: BlockId,
       level: StorageLevel,
       classTag: ClassTag[T],


@@ -36,6 +36,9 @@ object MimaExcludes {
   // Exclude rules for 3.0.x
   lazy val v30excludes = v24excludes ++ Seq(
+    // [SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC
+    ProblemFilters.exclude[InaccessibleMethodProblem]("java.lang.Object.finalize"),
+
     // [SPARK-27366][CORE] Support GPU Resources in Spark job scheduling
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resources"),


@@ -16,6 +16,7 @@
 #
 import os
 import random
+import time
 import tempfile
 import unittest
@@ -82,6 +83,23 @@ class BroadcastTest(unittest.TestCase):
     def test_broadcast_value_driver_encryption(self):
         self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))

+    def test_broadcast_value_against_gc(self):
+        # Test broadcast value against gc.
+        conf = SparkConf()
+        conf.setMaster("local[1,1]")
+        conf.set("spark.memory.fraction", "0.0001")
+        self.sc = SparkContext(conf=conf)
+        b = self.sc.broadcast([100])
+        try:
+            res = self.sc.parallelize([0], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
+            self.assertEqual([0], res)
+            self.sc._jvm.java.lang.System.gc()
+            time.sleep(5)
+            res = self.sc.parallelize([1], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
+            self.assertEqual([100], res)
+        finally:
+            b.destroy()
+
 class BroadcastFrameProtocolTest(unittest.TestCase):