[SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC

## What changes were proposed in this pull request?

Currently, PythonBroadcast may delete its data file while a Python worker still needs it. This happens because PythonBroadcast overrides the `finalize()` method to delete its data file, so when a GC occurs and there are no more references to the broadcast variable, `finalize()` may be triggered and delete the data file. It also means the data behind a Python broadcast variable cannot be deleted when `unpersist()`/`destroy()` is called, but instead relies on GC.

In this PR, we remove the `finalize()` method and instead map the PythonBroadcast data file to a BroadcastBlock (which has the same broadcast id as the broadcast variable that wraps this PythonBroadcast) when PythonBroadcast is deserialized. As a result, the data file can be deleted just like the other pieces of the broadcast variable when `unpersist()`/`destroy()` is called, without relying on GC any more.
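
For illustration, a minimal PySpark sketch of the resulting lifecycle (assuming an existing `SparkContext` named `sc`; this snippet is not part of the patch):

```python
# Illustrative sketch only (assumes a running SparkContext `sc`).
b = sc.broadcast([1, 2, 3])

# Executors read b.value from the broadcast block backed by the data file.
assert sc.parallelize([0], 1).map(lambda _: b.value[0]).collect() == [1]

# With this change, destroy()/unpersist() removes the BroadcastBlock (and the
# underlying data file) together with the rest of the broadcast variable,
# instead of waiting for GC to run PythonBroadcast.finalize().
b.destroy()
```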

## How was this patch tested?

Added a Python test, and tested manually (verified creating and deleting the broadcast block).
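
A rough sketch of that manual check (a hypothetical session, mirroring the added Python test below):

```python
import time
from pyspark import SparkConf, SparkContext

# Hypothetical manual verification: a JVM GC alone must no longer delete the
# broadcast data file.
conf = SparkConf().setMaster("local[1,1]").set("spark.memory.fraction", "0.0001")
sc = SparkContext(conf=conf)
b = sc.broadcast([100])
try:
    sc._jvm.java.lang.System.gc()
    time.sleep(5)
    # The value is still readable after GC because the data file is now a
    # BroadcastBlock owned by the BlockManager.
    assert sc.parallelize([1], 1).map(lambda _: b.value[0]).collect() == [100]
finally:
    b.destroy()  # this, not GC, is what removes the broadcast block now
    sc.stop()
```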

Closes #25262 from Ngone51/SPARK-28486.

Authored-by: wuyi <ngone_5451@163.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
wuyi 2019-08-05 20:18:53 +09:00 committed by HyukjinKwon
parent cae500a255
commit 94499af6f0
5 changed files with 67 additions and 28 deletions


@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -39,6 +40,7 @@ import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util._
@@ -697,10 +699,11 @@ private[spark] class PythonAccumulatorV2(
}
}
// scalastyle:off no.finalize
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
with Logging {
// id of the Broadcast variable which wrapped this PythonBroadcast
private var broadcastId: Long = _
private var encryptionServer: SocketAuthServer[Unit] = null
private var decryptionServer: SocketAuthServer[Unit] = null
@@ -708,6 +711,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
* Read data from disks, then copy it to `out`
*/
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
out.writeLong(broadcastId)
val in = new FileInputStream(new File(path))
try {
Utils.copyStream(in, out)
@@ -717,33 +721,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
/**
* Write data into disk, using randomly generated name.
* Write data into disk and map it to a broadcast block.
*/
private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
val file = File.createTempFile("broadcast", "", dir)
path = file.getAbsolutePath
val out = new FileOutputStream(file)
Utils.tryWithSafeFinally {
Utils.copyStream(in, out)
} {
out.close()
}
}
/**
* Delete the file once the object is GCed.
*/
override def finalize() {
if (!path.isEmpty) {
val file = new File(path)
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file.getPath}")
private def readObject(in: ObjectInputStream): Unit = {
broadcastId = in.readLong()
val blockId = BroadcastBlockId(broadcastId, "python")
val blockManager = SparkEnv.get.blockManager
val diskBlockManager = blockManager.diskBlockManager
if (!diskBlockManager.containsBlock(blockId)) {
Utils.tryOrIOException {
val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
val file = File.createTempFile("broadcast", "", dir)
val out = new FileOutputStream(file)
Utils.tryWithSafeFinally {
val size = Utils.copyStream(in, out)
val ct = implicitly[ClassTag[Object]]
// SPARK-28486: map broadcast file to a broadcast block, so that it could be
// cleared by unpersist/destroy rather than gc(previously).
val blockStoreUpdater = blockManager.
TempFileBasedBlockStoreUpdater(blockId, StorageLevel.DISK_ONLY, ct, file, size)
blockStoreUpdater.save()
} {
out.close()
}
}
}
super.finalize()
path = diskBlockManager.getFile(blockId).getAbsolutePath
}
def setBroadcastId(bid: Long): Unit = {
this.broadcastId = bid
}
def setupEncryptionServer(): Array[Any] = {
@@ -783,7 +790,6 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
def waitTillDataReceived(): Unit = encryptionServer.getResult()
}
// scalastyle:on no.finalize
/**
* The inverse of pyspark's ChunkedStream for sending data of unknown size.


@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.internal.Logging
private[spark] class BroadcastManager(
@@ -59,7 +60,18 @@ private[spark] class BroadcastManager(
}
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
val bid = nextBroadcastId.getAndIncrement()
value_ match {
case pb: PythonBroadcast =>
// SPARK-28486: attach this new broadcast variable's id to the PythonBroadcast,
// so that underlying data file of PythonBroadcast could be mapped to the
// BroadcastBlockId according to this id. Please see the specific usage of the
// id in PythonBroadcast.readObject().
pb.setBroadcastId(bid)
case _ => // do nothing
}
broadcastFactory.newBroadcast[T](value_, isLocal, bid)
}
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {


@@ -211,7 +211,7 @@ private[spark] class BlockManager(
*
* @param blockSize the decrypted size of the block
*/
private abstract class BlockStoreUpdater[T](
private[spark] abstract class BlockStoreUpdater[T](
blockSize: Long,
blockId: BlockId,
level: StorageLevel,
@@ -357,7 +357,7 @@ private[spark] class BlockManager(
/**
* Helper for storing a block based from bytes already in a local temp file.
*/
private case class TempFileBasedBlockStoreUpdater[T](
private[spark] case class TempFileBasedBlockStoreUpdater[T](
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[T],


@@ -36,6 +36,9 @@ object MimaExcludes {
// Exclude rules for 3.0.x
lazy val v30excludes = v24excludes ++ Seq(
// [SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC
ProblemFilters.exclude[InaccessibleMethodProblem]("java.lang.Object.finalize"),
// [SPARK-27366][CORE] Support GPU Resources in Spark job scheduling
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resources"),


@@ -16,6 +16,7 @@
#
import os
import random
import time
import tempfile
import unittest
@@ -82,6 +83,23 @@ class BroadcastTest(unittest.TestCase):
def test_broadcast_value_driver_encryption(self):
self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))
def test_broadcast_value_against_gc(self):
# Test broadcast value against gc.
conf = SparkConf()
conf.setMaster("local[1,1]")
conf.set("spark.memory.fraction", "0.0001")
self.sc = SparkContext(conf=conf)
b = self.sc.broadcast([100])
try:
res = self.sc.parallelize([0], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
self.assertEqual([0], res)
self.sc._jvm.java.lang.System.gc()
time.sleep(5)
res = self.sc.parallelize([1], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
self.assertEqual([100], res)
finally:
b.destroy()
class BroadcastFrameProtocolTest(unittest.TestCase):