[SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC
## What changes were proposed in this pull request?

Currently, PythonBroadcast may delete its data file while a Python worker still needs it. This happens because PythonBroadcast overrides the `finalize()` method to delete its data file, so when GC runs and there are no longer any references to the broadcast variable, `finalize()` may be triggered and delete the file. It also means the data behind a Python broadcast variable cannot be deleted when `unpersist()`/`destroy()` is called, but instead relies on GC.

In this PR, we remove the `finalize()` method and map the PythonBroadcast data file to a BroadcastBlock (which has the same broadcast id as the broadcast variable that wraps this PythonBroadcast) while PythonBroadcast is being deserialized. As a result, the data file is deleted just like the other pieces of the broadcast variable when `unpersist()`/`destroy()` is called, and no longer relies on GC.

## How was this patch tested?

Added a Python test, and tested manually (verified creating/deleting the broadcast block).

Closes #25262 from Ngone51/SPARK-28486.

Authored-by: wuyi <ngone_5451@163.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
parent cae500a255
commit 94499af6f0
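At a glance, the fix moves the life cycle of the Python broadcast data file from `finalize()` into the block manager. The sketch below condenses the new executor-side `PythonBroadcast.readObject` from the diff that follows, with explanatory comments added; every identifier comes from the patch itself, and it is an excerpt of the class, not a standalone program:

```scala
// On deserialization, register the received bytes as a disk-backed broadcast
// block instead of remembering a loose temp file for finalize() to delete.
private def readObject(in: ObjectInputStream): Unit = {
  broadcastId = in.readLong()  // id attached by BroadcastManager.newBroadcast
  val blockId = BroadcastBlockId(broadcastId, "python")
  val blockManager = SparkEnv.get.blockManager
  val diskBlockManager = blockManager.diskBlockManager
  if (!diskBlockManager.containsBlock(blockId)) {
    Utils.tryOrIOException {
      // Spill the stream to a temp file, then hand that file to the
      // BlockManager under the wrapping variable's broadcast id.
      val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
      val file = File.createTempFile("broadcast", "", dir)
      val out = new FileOutputStream(file)
      Utils.tryWithSafeFinally {
        val size = Utils.copyStream(in, out)
        blockManager.TempFileBasedBlockStoreUpdater(
          blockId, StorageLevel.DISK_ONLY, implicitly[ClassTag[Object]], file, size).save()
      } {
        out.close()
      }
    }
  }
  // The data file now lives at the block's location, so unpersist()/destroy()
  // deletes it along with the rest of the broadcast variable.
  path = diskBlockManager.getFile(blockId).getAbsolutePath
}
```

The driver-side half is the `BroadcastManager.newBroadcast` change below, which calls `setBroadcastId` on the PythonBroadcast before serialization, while `writeObject` prepends that id to the stream so the executor can rebuild the same `BroadcastBlockId`.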
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
@@ -39,6 +40,7 @@ import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util._
 
 
@@ -697,10 +699,11 @@ private[spark] class PythonAccumulatorV2(
   }
 }
 
-// scalastyle:off no.finalize
 private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
   with Logging {
 
+  // id of the Broadcast variable which wrapped this PythonBroadcast
+  private var broadcastId: Long = _
   private var encryptionServer: SocketAuthServer[Unit] = null
   private var decryptionServer: SocketAuthServer[Unit] = null
 
@@ -708,6 +711,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
    * Read data from disks, then copy it to `out`
    */
   private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+    out.writeLong(broadcastId)
     val in = new FileInputStream(new File(path))
     try {
       Utils.copyStream(in, out)
@@ -717,33 +721,36 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
   }
 
   /**
-   * Write data into disk, using randomly generated name.
+   * Write data into disk and map it to a broadcast block.
    */
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
-    val file = File.createTempFile("broadcast", "", dir)
-    path = file.getAbsolutePath
-    val out = new FileOutputStream(file)
-    Utils.tryWithSafeFinally {
-      Utils.copyStream(in, out)
-    } {
-      out.close()
-    }
-  }
-
-  /**
-   * Delete the file once the object is GCed.
-   */
-  override def finalize() {
-    if (!path.isEmpty) {
-      val file = new File(path)
-      if (file.exists()) {
-        if (!file.delete()) {
-          logWarning(s"Error deleting ${file.getPath}")
+  private def readObject(in: ObjectInputStream): Unit = {
+    broadcastId = in.readLong()
+    val blockId = BroadcastBlockId(broadcastId, "python")
+    val blockManager = SparkEnv.get.blockManager
+    val diskBlockManager = blockManager.diskBlockManager
+    if (!diskBlockManager.containsBlock(blockId)) {
+      Utils.tryOrIOException {
+        val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
+        val file = File.createTempFile("broadcast", "", dir)
+        val out = new FileOutputStream(file)
+        Utils.tryWithSafeFinally {
+          val size = Utils.copyStream(in, out)
+          val ct = implicitly[ClassTag[Object]]
+          // SPARK-28486: map broadcast file to a broadcast block, so that it could be
+          // cleared by unpersist/destroy rather than gc(previously).
+          val blockStoreUpdater = blockManager.
+            TempFileBasedBlockStoreUpdater(blockId, StorageLevel.DISK_ONLY, ct, file, size)
+          blockStoreUpdater.save()
+        } {
+          out.close()
         }
       }
     }
-    super.finalize()
+    path = diskBlockManager.getFile(blockId).getAbsolutePath
+  }
+
+  def setBroadcastId(bid: Long): Unit = {
+    this.broadcastId = bid
   }
 
   def setupEncryptionServer(): Array[Any] = {
@@ -783,7 +790,6 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
 
   def waitTillDataReceived(): Unit = encryptionServer.getResult()
 }
-// scalastyle:on no.finalize
 
 /**
  * The inverse of pyspark's ChunkedStream for sending data of unknown size.
core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}
 
 import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.api.python.PythonBroadcast
 import org.apache.spark.internal.Logging
 
 private[spark] class BroadcastManager(
@@ -59,7 +60,18 @@ private[spark] class BroadcastManager(
   }
 
   def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+    val bid = nextBroadcastId.getAndIncrement()
+    value_ match {
+      case pb: PythonBroadcast =>
+        // SPARK-28486: attach this new broadcast variable's id to the PythonBroadcast,
+        // so that underlying data file of PythonBroadcast could be mapped to the
+        // BroadcastBlockId according to this id. Please see the specific usage of the
+        // id in PythonBroadcast.readObject().
+        pb.setBroadcastId(bid)
+
+      case _ => // do nothing
+    }
+    broadcastFactory.newBroadcast[T](value_, isLocal, bid)
   }
 
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
core/src/main/scala/org/apache/spark/storage/BlockManager.scala

@@ -211,7 +211,7 @@ private[spark] class BlockManager(
    *
    * @param blockSize the decrypted size of the block
    */
-  private abstract class BlockStoreUpdater[T](
+  private[spark] abstract class BlockStoreUpdater[T](
       blockSize: Long,
       blockId: BlockId,
       level: StorageLevel,
@@ -357,7 +357,7 @@ private[spark] class BlockManager(
   /**
    * Helper for storing a block based from bytes already in a local temp file.
    */
-  private case class TempFileBasedBlockStoreUpdater[T](
+  private[spark] case class TempFileBasedBlockStoreUpdater[T](
       blockId: BlockId,
       level: StorageLevel,
       classTag: ClassTag[T],
project/MimaExcludes.scala

@@ -36,6 +36,9 @@ object MimaExcludes {
 
   // Exclude rules for 3.0.x
   lazy val v30excludes = v24excludes ++ Seq(
+    // [SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC
+    ProblemFilters.exclude[InaccessibleMethodProblem]("java.lang.Object.finalize"),
+
     // [SPARK-27366][CORE] Support GPU Resources in Spark job scheduling
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resources"),
 
python/pyspark/tests/test_broadcast.py

@@ -16,6 +16,7 @@
 #
 import os
 import random
+import time
 import tempfile
 import unittest
 
@@ -82,6 +83,23 @@ class BroadcastTest(unittest.TestCase):
     def test_broadcast_value_driver_encryption(self):
         self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))
 
+    def test_broadcast_value_against_gc(self):
+        # Test broadcast value against gc.
+        conf = SparkConf()
+        conf.setMaster("local[1,1]")
+        conf.set("spark.memory.fraction", "0.0001")
+        self.sc = SparkContext(conf=conf)
+        b = self.sc.broadcast([100])
+        try:
+            res = self.sc.parallelize([0], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
+            self.assertEqual([0], res)
+            self.sc._jvm.java.lang.System.gc()
+            time.sleep(5)
+            res = self.sc.parallelize([1], 1).map(lambda x: 0 if x == 0 else b.value[0]).collect()
+            self.assertEqual([100], res)
+        finally:
+            b.destroy()
+
 
 class BroadcastFrameProtocolTest(unittest.TestCase):
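For context, here is a minimal interactive session mirroring what the new test pins down; the `local[1,1]` master (one core, one task retry) and the tiny `spark.memory.fraction` are the test's way of making GC-driven failures reproducible, and the `sc.stop()` at the end is illustrative, not part of the patch:

```python
from pyspark import SparkConf, SparkContext

# Same setup as test_broadcast_value_against_gc: a single retryable task
# and heavy GC pressure on the driver JVM.
conf = SparkConf().setMaster("local[1,1]").set("spark.memory.fraction", "0.0001")
sc = SparkContext(conf=conf)

b = sc.broadcast([100])
sc._jvm.java.lang.System.gc()  # before SPARK-28486 this could finalize() away b's data file
print(sc.parallelize([1], 1).map(lambda x: b.value[0]).collect())  # [100]: file survives GC
b.destroy()  # now deterministically deletes the underlying broadcast block
sc.stop()
```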