[SPARK-36062][PYTHON] Try to capture faulthanlder when a Python worker crashes
### What changes were proposed in this pull request?
Try to capture the error message from the `faulthandler` when the Python worker crashes.
### Why are the changes needed?
Currently, we just see an error message saying `"exited unexpectedly (crashed)"` when the UDFs causes the Python worker to crash by like segmentation fault.
We should take advantage of [`faulthandler`](https://docs.python.org/3/library/faulthandler.html) and try to capture the error message from the `faulthandler`.
### Does this PR introduce _any_ user-facing change?
Yes, when a Spark config `spark.python.worker.faulthandler.enabled` is `true`, the stack trace will be seen in the error message when the Python worker crashes.
```py
>>> def f():
... import ctypes
... ctypes.string_at(0)
...
>>> sc.parallelize([1]).map(lambda x: f()).count()
```
```
org.apache.spark.SparkException: Python worker exited unexpectedly (crashed): Fatal Python error: Segmentation fault
Current thread 0x000000010965b5c0 (most recent call first):
File "/.../ctypes/__init__.py", line 525 in string_at
File "<stdin>", line 3 in f
File "<stdin>", line 1 in <lambda>
...
```
### How was this patch tested?
Added some tests, and manually.
Closes #33273 from ueshin/issues/SPARK-36062/faulthandler.
Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 115b8a180f
)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
86676298d3
commit
55111cafd1
|
@ -113,7 +113,9 @@ class SparkEnv (
|
|||
}
|
||||
|
||||
private[spark]
|
||||
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
|
||||
def createPythonWorker(
|
||||
pythonExec: String,
|
||||
envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
|
||||
synchronized {
|
||||
val key = (pythonExec, envVars)
|
||||
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.io._
|
|||
import java.net._
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.charset.StandardCharsets.UTF_8
|
||||
import java.nio.file.{Files => JavaFiles, Path}
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
|
@ -65,6 +66,15 @@ private[spark] object PythonEvalType {
|
|||
}
|
||||
}
|
||||
|
||||
private object BasePythonRunner {
|
||||
|
||||
private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
|
||||
|
||||
private def faultHandlerLogPath(pid: Int): Path = {
|
||||
new File(faultHandlerLogDir, pid.toString).toPath
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A helper class to run Python mapPartition/UDFs in Spark.
|
||||
*
|
||||
|
@ -83,6 +93,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
|||
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
|
||||
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
|
||||
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
|
||||
private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
|
||||
protected val simplifiedTraceback: Boolean = false
|
||||
|
||||
// All the Python functions should have the same exec, version and envvars.
|
||||
|
@ -143,7 +154,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
|||
}
|
||||
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
|
||||
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
|
||||
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
|
||||
if (faultHandlerEnabled) {
|
||||
envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
|
||||
}
|
||||
|
||||
val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
|
||||
pythonExec, envVars.asScala.toMap)
|
||||
// Whether is the worker released into idle pool or closed. When any codes try to release or
|
||||
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
|
||||
// sure there is only one winner that is going to release or close the worker.
|
||||
|
@ -180,7 +196,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
|||
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
|
||||
|
||||
val stdoutIterator = newReaderIterator(
|
||||
stream, writerThread, startTime, env, worker, releasedOrClosed, context)
|
||||
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
|
||||
new InterruptibleIterator(context, stdoutIterator)
|
||||
}
|
||||
|
||||
|
@ -197,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
|||
startTime: Long,
|
||||
env: SparkEnv,
|
||||
worker: Socket,
|
||||
pid: Option[Int],
|
||||
releasedOrClosed: AtomicBoolean,
|
||||
context: TaskContext): Iterator[OUT]
|
||||
|
||||
|
@ -468,6 +485,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
|||
startTime: Long,
|
||||
env: SparkEnv,
|
||||
worker: Socket,
|
||||
pid: Option[Int],
|
||||
releasedOrClosed: AtomicBoolean,
|
||||
context: TaskContext)
|
||||
extends Iterator[OUT] {
|
||||
|
@ -556,6 +574,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
|||
logError("This may have been caused by a prior exception:", writerThread.exception.get)
|
||||
throw writerThread.exception.get
|
||||
|
||||
case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
|
||||
JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
|
||||
val path = BasePythonRunner.faultHandlerLogPath(pid.get)
|
||||
val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
|
||||
JavaFiles.deleteIfExists(path)
|
||||
throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)
|
||||
|
||||
case eof: EOFException =>
|
||||
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
|
||||
}
|
||||
|
@ -654,9 +679,11 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
|
|||
startTime: Long,
|
||||
env: SparkEnv,
|
||||
worker: Socket,
|
||||
pid: Option[Int],
|
||||
releasedOrClosed: AtomicBoolean,
|
||||
context: TaskContext): Iterator[Array[Byte]] = {
|
||||
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
||||
new ReaderIterator(
|
||||
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
|
||||
|
||||
protected override def read(): Array[Byte] = {
|
||||
if (writerThread.exception.isDefined) {
|
||||
|
|
|
@ -95,11 +95,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
|||
envVars.getOrElse("PYTHONPATH", ""),
|
||||
sys.env.getOrElse("PYTHONPATH", ""))
|
||||
|
||||
def create(): Socket = {
|
||||
def create(): (Socket, Option[Int]) = {
|
||||
if (useDaemon) {
|
||||
self.synchronized {
|
||||
if (idleWorkers.nonEmpty) {
|
||||
return idleWorkers.dequeue()
|
||||
val worker = idleWorkers.dequeue()
|
||||
return (worker, daemonWorkers.get(worker))
|
||||
}
|
||||
}
|
||||
createThroughDaemon()
|
||||
|
@ -113,9 +114,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
|||
* processes itself to avoid the high cost of forking from Java. This currently only works
|
||||
* on UNIX-based systems.
|
||||
*/
|
||||
private def createThroughDaemon(): Socket = {
|
||||
private def createThroughDaemon(): (Socket, Option[Int]) = {
|
||||
|
||||
def createSocket(): Socket = {
|
||||
def createSocket(): (Socket, Option[Int]) = {
|
||||
val socket = new Socket(daemonHost, daemonPort)
|
||||
val pid = new DataInputStream(socket.getInputStream).readInt()
|
||||
if (pid < 0) {
|
||||
|
@ -124,7 +125,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
|||
|
||||
authHelper.authToServer(socket)
|
||||
daemonWorkers.put(socket, pid)
|
||||
socket
|
||||
(socket, Some(pid))
|
||||
}
|
||||
|
||||
self.synchronized {
|
||||
|
@ -148,7 +149,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
|||
/**
|
||||
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
|
||||
*/
|
||||
private def createSimpleWorker(): Socket = {
|
||||
private def createSimpleWorker(): (Socket, Option[Int]) = {
|
||||
var serverSocket: ServerSocket = null
|
||||
try {
|
||||
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
|
||||
|
@ -173,10 +174,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
|||
try {
|
||||
val socket = serverSocket.accept()
|
||||
authHelper.authClient(socket)
|
||||
// TODO: When we drop JDK 8, we can just use worker.pid()
|
||||
val pid = new DataInputStream(socket.getInputStream).readInt()
|
||||
if (pid < 0) {
|
||||
throw new IllegalStateException("Python failed to launch worker with code " + pid)
|
||||
}
|
||||
self.synchronized {
|
||||
simpleWorkers.put(socket, worker)
|
||||
}
|
||||
return socket
|
||||
return (socket, Some(pid))
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
throw new SparkException("Python worker failed to connect back.", e)
|
||||
|
|
|
@ -56,4 +56,12 @@ private[spark] object Python {
|
|||
.version("3.1.0")
|
||||
.timeConf(TimeUnit.SECONDS)
|
||||
.createWithDefaultString("15s")
|
||||
|
||||
val PYTHON_WORKER_FAULTHANLDER_ENABLED = ConfigBuilder("spark.python.worker.faulthandler.enabled")
|
||||
.doc("When true, Python workers set up the faulthandler for the case when the Python worker " +
|
||||
"exits unexpectedly (crashes), and shows the stack trace of the moment the Python worker " +
|
||||
"crashes in the error message if captured successfully.")
|
||||
.version("3.2.0")
|
||||
.booleanConf
|
||||
.createWithDefault(false)
|
||||
}
|
||||
|
|
|
@ -206,6 +206,35 @@ class WorkerMemoryTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
self.sc.stop()
|
||||
|
||||
|
||||
class WorkerSegfaultTest(ReusedPySparkTestCase):
|
||||
|
||||
@classmethod
|
||||
def conf(cls):
|
||||
_conf = super(WorkerSegfaultTest, cls).conf()
|
||||
_conf.set("spark.python.worker.faulthandler.enabled", "true")
|
||||
return _conf
|
||||
|
||||
def test_python_segfault(self):
|
||||
try:
|
||||
def f():
|
||||
import ctypes
|
||||
ctypes.string_at(0)
|
||||
|
||||
self.sc.parallelize([1]).map(lambda x: f()).count()
|
||||
except Py4JJavaError as e:
|
||||
self.assertRegex(str(e), "Segmentation fault")
|
||||
|
||||
|
||||
class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest):
|
||||
|
||||
@classmethod
|
||||
def conf(cls):
|
||||
_conf = super(WorkerSegfaultNonDaemonTest, cls).conf()
|
||||
_conf.set("spark.python.use.daemon", "false")
|
||||
return _conf
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
from pyspark.tests.test_worker import * # noqa: F401
|
||||
|
|
|
@ -31,6 +31,7 @@ except ImportError:
|
|||
has_resource_module = False
|
||||
import traceback
|
||||
import warnings
|
||||
import faulthandler
|
||||
|
||||
from pyspark.accumulators import _accumulatorRegistry
|
||||
from pyspark.broadcast import Broadcast, _broadcastRegistry
|
||||
|
@ -463,7 +464,13 @@ def read_udfs(pickleSer, infile, eval_type):
|
|||
|
||||
|
||||
def main(infile, outfile):
|
||||
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
|
||||
try:
|
||||
if faulthandler_log_path:
|
||||
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
|
||||
faulthandler_log_file = open(faulthandler_log_path, "w")
|
||||
faulthandler.enable(file=faulthandler_log_file)
|
||||
|
||||
boot_time = time.time()
|
||||
split_index = read_int(infile)
|
||||
if split_index == -1: # for unit tests
|
||||
|
@ -636,6 +643,11 @@ def main(infile, outfile):
|
|||
print("PySpark worker failed with exception:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
sys.exit(-1)
|
||||
finally:
|
||||
if faulthandler_log_path:
|
||||
faulthandler.disable()
|
||||
faulthandler_log_file.close()
|
||||
os.remove(faulthandler_log_path)
|
||||
finish_time = time.time()
|
||||
report_times(outfile, boot_time, init_time, finish_time)
|
||||
write_long(shuffle.MemoryBytesSpilled, outfile)
|
||||
|
@ -661,4 +673,7 @@ if __name__ == '__main__':
|
|||
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
|
||||
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
|
||||
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
|
||||
# TODO: Remove thw following two lines and use `Process.pid()` when we drop JDK 8.
|
||||
write_int(os.getpid(), sock_file)
|
||||
sock_file.flush()
|
||||
main(sock_file, sock_file)
|
||||
|
|
|
@ -43,10 +43,12 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
|
|||
startTime: Long,
|
||||
env: SparkEnv,
|
||||
worker: Socket,
|
||||
pid: Option[Int],
|
||||
releasedOrClosed: AtomicBoolean,
|
||||
context: TaskContext): Iterator[ColumnarBatch] = {
|
||||
|
||||
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
||||
new ReaderIterator(
|
||||
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
|
||||
|
||||
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
|
||||
s"stdin reader for $pythonExec", 0, Long.MaxValue)
|
||||
|
|
|
@ -62,9 +62,11 @@ class PythonUDFRunner(
|
|||
startTime: Long,
|
||||
env: SparkEnv,
|
||||
worker: Socket,
|
||||
pid: Option[Int],
|
||||
releasedOrClosed: AtomicBoolean,
|
||||
context: TaskContext): Iterator[Array[Byte]] = {
|
||||
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
||||
new ReaderIterator(
|
||||
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
|
||||
|
||||
protected override def read(): Array[Byte] = {
|
||||
if (writerThread.exception.isDefined) {
|
||||
|
|
Loading…
Reference in a new issue