[SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes
### What changes were proposed in this pull request?
Try to capture the error message from the `faulthandler` when the Python worker crashes.
### Why are the changes needed?
Currently, we just see an error message saying `"exited unexpectedly (crashed)"` when a UDF causes the Python worker to crash, for example due to a segmentation fault.
We should take advantage of [`faulthandler`](https://docs.python.org/3/library/faulthandler.html) and try to capture the error message from the `faulthandler`.
### Does this PR introduce _any_ user-facing change?
Yes, when the Spark config `spark.python.worker.faulthandler.enabled` is `true`, the stack trace will be shown in the error message when the Python worker crashes.
```py
>>> def f():
... import ctypes
... ctypes.string_at(0)
...
>>> sc.parallelize([1]).map(lambda x: f()).count()
```
```
org.apache.spark.SparkException: Python worker exited unexpectedly (crashed): Fatal Python error: Segmentation fault
Current thread 0x000000010965b5c0 (most recent call first):
File "/.../ctypes/__init__.py", line 525 in string_at
File "<stdin>", line 3 in f
File "<stdin>", line 1 in <lambda>
...
```
### How was this patch tested?
Added some tests, and manually.
Closes #33273 from ueshin/issues/SPARK-36062/faulthandler.
Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 115b8a180f
)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
86676298d3
commit
55111cafd1
|
@ -113,7 +113,9 @@ class SparkEnv (
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark]
|
private[spark]
|
||||||
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
|
def createPythonWorker(
|
||||||
|
pythonExec: String,
|
||||||
|
envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
|
||||||
synchronized {
|
synchronized {
|
||||||
val key = (pythonExec, envVars)
|
val key = (pythonExec, envVars)
|
||||||
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
|
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
|
||||||
|
|
|
@ -21,6 +21,7 @@ import java.io._
|
||||||
import java.net._
|
import java.net._
|
||||||
import java.nio.charset.StandardCharsets
|
import java.nio.charset.StandardCharsets
|
||||||
import java.nio.charset.StandardCharsets.UTF_8
|
import java.nio.charset.StandardCharsets.UTF_8
|
||||||
|
import java.nio.file.{Files => JavaFiles, Path}
|
||||||
import java.util.concurrent.ConcurrentHashMap
|
import java.util.concurrent.ConcurrentHashMap
|
||||||
import java.util.concurrent.atomic.AtomicBoolean
|
import java.util.concurrent.atomic.AtomicBoolean
|
||||||
|
|
||||||
|
@ -65,6 +66,15 @@ private[spark] object PythonEvalType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private object BasePythonRunner {
|
||||||
|
|
||||||
|
private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
|
||||||
|
|
||||||
|
private def faultHandlerLogPath(pid: Int): Path = {
|
||||||
|
new File(faultHandlerLogDir, pid.toString).toPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A helper class to run Python mapPartition/UDFs in Spark.
|
* A helper class to run Python mapPartition/UDFs in Spark.
|
||||||
*
|
*
|
||||||
|
@ -83,6 +93,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
|
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
|
||||||
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
|
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
|
||||||
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
|
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
|
||||||
|
private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
|
||||||
protected val simplifiedTraceback: Boolean = false
|
protected val simplifiedTraceback: Boolean = false
|
||||||
|
|
||||||
// All the Python functions should have the same exec, version and envvars.
|
// All the Python functions should have the same exec, version and envvars.
|
||||||
|
@ -143,7 +154,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
}
|
}
|
||||||
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
|
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
|
||||||
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
|
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
|
||||||
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
|
if (faultHandlerEnabled) {
|
||||||
|
envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
|
||||||
|
}
|
||||||
|
|
||||||
|
val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
|
||||||
|
pythonExec, envVars.asScala.toMap)
|
||||||
// Whether is the worker released into idle pool or closed. When any codes try to release or
|
// Whether is the worker released into idle pool or closed. When any codes try to release or
|
||||||
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
|
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
|
||||||
// sure there is only one winner that is going to release or close the worker.
|
// sure there is only one winner that is going to release or close the worker.
|
||||||
|
@ -180,7 +196,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
|
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
|
||||||
|
|
||||||
val stdoutIterator = newReaderIterator(
|
val stdoutIterator = newReaderIterator(
|
||||||
stream, writerThread, startTime, env, worker, releasedOrClosed, context)
|
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
|
||||||
new InterruptibleIterator(context, stdoutIterator)
|
new InterruptibleIterator(context, stdoutIterator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
startTime: Long,
|
startTime: Long,
|
||||||
env: SparkEnv,
|
env: SparkEnv,
|
||||||
worker: Socket,
|
worker: Socket,
|
||||||
|
pid: Option[Int],
|
||||||
releasedOrClosed: AtomicBoolean,
|
releasedOrClosed: AtomicBoolean,
|
||||||
context: TaskContext): Iterator[OUT]
|
context: TaskContext): Iterator[OUT]
|
||||||
|
|
||||||
|
@ -468,6 +485,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
startTime: Long,
|
startTime: Long,
|
||||||
env: SparkEnv,
|
env: SparkEnv,
|
||||||
worker: Socket,
|
worker: Socket,
|
||||||
|
pid: Option[Int],
|
||||||
releasedOrClosed: AtomicBoolean,
|
releasedOrClosed: AtomicBoolean,
|
||||||
context: TaskContext)
|
context: TaskContext)
|
||||||
extends Iterator[OUT] {
|
extends Iterator[OUT] {
|
||||||
|
@ -556,6 +574,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
|
||||||
logError("This may have been caused by a prior exception:", writerThread.exception.get)
|
logError("This may have been caused by a prior exception:", writerThread.exception.get)
|
||||||
throw writerThread.exception.get
|
throw writerThread.exception.get
|
||||||
|
|
||||||
|
case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
|
||||||
|
JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
|
||||||
|
val path = BasePythonRunner.faultHandlerLogPath(pid.get)
|
||||||
|
val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
|
||||||
|
JavaFiles.deleteIfExists(path)
|
||||||
|
throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)
|
||||||
|
|
||||||
case eof: EOFException =>
|
case eof: EOFException =>
|
||||||
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
|
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
|
||||||
}
|
}
|
||||||
|
@ -654,9 +679,11 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
|
||||||
startTime: Long,
|
startTime: Long,
|
||||||
env: SparkEnv,
|
env: SparkEnv,
|
||||||
worker: Socket,
|
worker: Socket,
|
||||||
|
pid: Option[Int],
|
||||||
releasedOrClosed: AtomicBoolean,
|
releasedOrClosed: AtomicBoolean,
|
||||||
context: TaskContext): Iterator[Array[Byte]] = {
|
context: TaskContext): Iterator[Array[Byte]] = {
|
||||||
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
new ReaderIterator(
|
||||||
|
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
|
||||||
|
|
||||||
protected override def read(): Array[Byte] = {
|
protected override def read(): Array[Byte] = {
|
||||||
if (writerThread.exception.isDefined) {
|
if (writerThread.exception.isDefined) {
|
||||||
|
|
|
@ -95,11 +95,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
||||||
envVars.getOrElse("PYTHONPATH", ""),
|
envVars.getOrElse("PYTHONPATH", ""),
|
||||||
sys.env.getOrElse("PYTHONPATH", ""))
|
sys.env.getOrElse("PYTHONPATH", ""))
|
||||||
|
|
||||||
def create(): Socket = {
|
def create(): (Socket, Option[Int]) = {
|
||||||
if (useDaemon) {
|
if (useDaemon) {
|
||||||
self.synchronized {
|
self.synchronized {
|
||||||
if (idleWorkers.nonEmpty) {
|
if (idleWorkers.nonEmpty) {
|
||||||
return idleWorkers.dequeue()
|
val worker = idleWorkers.dequeue()
|
||||||
|
return (worker, daemonWorkers.get(worker))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
createThroughDaemon()
|
createThroughDaemon()
|
||||||
|
@ -113,9 +114,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
||||||
* processes itself to avoid the high cost of forking from Java. This currently only works
|
* processes itself to avoid the high cost of forking from Java. This currently only works
|
||||||
* on UNIX-based systems.
|
* on UNIX-based systems.
|
||||||
*/
|
*/
|
||||||
private def createThroughDaemon(): Socket = {
|
private def createThroughDaemon(): (Socket, Option[Int]) = {
|
||||||
|
|
||||||
def createSocket(): Socket = {
|
def createSocket(): (Socket, Option[Int]) = {
|
||||||
val socket = new Socket(daemonHost, daemonPort)
|
val socket = new Socket(daemonHost, daemonPort)
|
||||||
val pid = new DataInputStream(socket.getInputStream).readInt()
|
val pid = new DataInputStream(socket.getInputStream).readInt()
|
||||||
if (pid < 0) {
|
if (pid < 0) {
|
||||||
|
@ -124,7 +125,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
||||||
|
|
||||||
authHelper.authToServer(socket)
|
authHelper.authToServer(socket)
|
||||||
daemonWorkers.put(socket, pid)
|
daemonWorkers.put(socket, pid)
|
||||||
socket
|
(socket, Some(pid))
|
||||||
}
|
}
|
||||||
|
|
||||||
self.synchronized {
|
self.synchronized {
|
||||||
|
@ -148,7 +149,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
||||||
/**
|
/**
|
||||||
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
|
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
|
||||||
*/
|
*/
|
||||||
private def createSimpleWorker(): Socket = {
|
private def createSimpleWorker(): (Socket, Option[Int]) = {
|
||||||
var serverSocket: ServerSocket = null
|
var serverSocket: ServerSocket = null
|
||||||
try {
|
try {
|
||||||
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
|
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
|
||||||
|
@ -173,10 +174,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
|
||||||
try {
|
try {
|
||||||
val socket = serverSocket.accept()
|
val socket = serverSocket.accept()
|
||||||
authHelper.authClient(socket)
|
authHelper.authClient(socket)
|
||||||
|
// TODO: When we drop JDK 8, we can just use worker.pid()
|
||||||
|
val pid = new DataInputStream(socket.getInputStream).readInt()
|
||||||
|
if (pid < 0) {
|
||||||
|
throw new IllegalStateException("Python failed to launch worker with code " + pid)
|
||||||
|
}
|
||||||
self.synchronized {
|
self.synchronized {
|
||||||
simpleWorkers.put(socket, worker)
|
simpleWorkers.put(socket, worker)
|
||||||
}
|
}
|
||||||
return socket
|
return (socket, Some(pid))
|
||||||
} catch {
|
} catch {
|
||||||
case e: Exception =>
|
case e: Exception =>
|
||||||
throw new SparkException("Python worker failed to connect back.", e)
|
throw new SparkException("Python worker failed to connect back.", e)
|
||||||
|
|
|
@ -56,4 +56,12 @@ private[spark] object Python {
|
||||||
.version("3.1.0")
|
.version("3.1.0")
|
||||||
.timeConf(TimeUnit.SECONDS)
|
.timeConf(TimeUnit.SECONDS)
|
||||||
.createWithDefaultString("15s")
|
.createWithDefaultString("15s")
|
||||||
|
|
||||||
|
val PYTHON_WORKER_FAULTHANLDER_ENABLED = ConfigBuilder("spark.python.worker.faulthandler.enabled")
|
||||||
|
.doc("When true, Python workers set up the faulthandler for the case when the Python worker " +
|
||||||
|
"exits unexpectedly (crashes), and shows the stack trace of the moment the Python worker " +
|
||||||
|
"crashes in the error message if captured successfully.")
|
||||||
|
.version("3.2.0")
|
||||||
|
.booleanConf
|
||||||
|
.createWithDefault(false)
|
||||||
}
|
}
|
||||||
|
|
|
@ -206,6 +206,35 @@ class WorkerMemoryTest(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.sc.stop()
|
self.sc.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerSegfaultTest(ReusedPySparkTestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def conf(cls):
|
||||||
|
_conf = super(WorkerSegfaultTest, cls).conf()
|
||||||
|
_conf.set("spark.python.worker.faulthandler.enabled", "true")
|
||||||
|
return _conf
|
||||||
|
|
||||||
|
def test_python_segfault(self):
|
||||||
|
try:
|
||||||
|
def f():
|
||||||
|
import ctypes
|
||||||
|
ctypes.string_at(0)
|
||||||
|
|
||||||
|
self.sc.parallelize([1]).map(lambda x: f()).count()
|
||||||
|
except Py4JJavaError as e:
|
||||||
|
self.assertRegex(str(e), "Segmentation fault")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def conf(cls):
|
||||||
|
_conf = super(WorkerSegfaultNonDaemonTest, cls).conf()
|
||||||
|
_conf.set("spark.python.use.daemon", "false")
|
||||||
|
return _conf
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import unittest
|
import unittest
|
||||||
from pyspark.tests.test_worker import * # noqa: F401
|
from pyspark.tests.test_worker import * # noqa: F401
|
||||||
|
|
|
@ -31,6 +31,7 @@ except ImportError:
|
||||||
has_resource_module = False
|
has_resource_module = False
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
import faulthandler
|
||||||
|
|
||||||
from pyspark.accumulators import _accumulatorRegistry
|
from pyspark.accumulators import _accumulatorRegistry
|
||||||
from pyspark.broadcast import Broadcast, _broadcastRegistry
|
from pyspark.broadcast import Broadcast, _broadcastRegistry
|
||||||
|
@ -463,7 +464,13 @@ def read_udfs(pickleSer, infile, eval_type):
|
||||||
|
|
||||||
|
|
||||||
def main(infile, outfile):
|
def main(infile, outfile):
|
||||||
|
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
|
||||||
try:
|
try:
|
||||||
|
if faulthandler_log_path:
|
||||||
|
faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
|
||||||
|
faulthandler_log_file = open(faulthandler_log_path, "w")
|
||||||
|
faulthandler.enable(file=faulthandler_log_file)
|
||||||
|
|
||||||
boot_time = time.time()
|
boot_time = time.time()
|
||||||
split_index = read_int(infile)
|
split_index = read_int(infile)
|
||||||
if split_index == -1: # for unit tests
|
if split_index == -1: # for unit tests
|
||||||
|
@ -636,6 +643,11 @@ def main(infile, outfile):
|
||||||
print("PySpark worker failed with exception:", file=sys.stderr)
|
print("PySpark worker failed with exception:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
finally:
|
||||||
|
if faulthandler_log_path:
|
||||||
|
faulthandler.disable()
|
||||||
|
faulthandler_log_file.close()
|
||||||
|
os.remove(faulthandler_log_path)
|
||||||
finish_time = time.time()
|
finish_time = time.time()
|
||||||
report_times(outfile, boot_time, init_time, finish_time)
|
report_times(outfile, boot_time, init_time, finish_time)
|
||||||
write_long(shuffle.MemoryBytesSpilled, outfile)
|
write_long(shuffle.MemoryBytesSpilled, outfile)
|
||||||
|
@ -661,4 +673,7 @@ if __name__ == '__main__':
|
||||||
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
|
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
|
||||||
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
|
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
|
||||||
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
|
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
|
||||||
|
# TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
|
||||||
|
write_int(os.getpid(), sock_file)
|
||||||
|
sock_file.flush()
|
||||||
main(sock_file, sock_file)
|
main(sock_file, sock_file)
|
||||||
|
|
|
@ -43,10 +43,12 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
|
||||||
startTime: Long,
|
startTime: Long,
|
||||||
env: SparkEnv,
|
env: SparkEnv,
|
||||||
worker: Socket,
|
worker: Socket,
|
||||||
|
pid: Option[Int],
|
||||||
releasedOrClosed: AtomicBoolean,
|
releasedOrClosed: AtomicBoolean,
|
||||||
context: TaskContext): Iterator[ColumnarBatch] = {
|
context: TaskContext): Iterator[ColumnarBatch] = {
|
||||||
|
|
||||||
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
new ReaderIterator(
|
||||||
|
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
|
||||||
|
|
||||||
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
|
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
|
||||||
s"stdin reader for $pythonExec", 0, Long.MaxValue)
|
s"stdin reader for $pythonExec", 0, Long.MaxValue)
|
||||||
|
|
|
@ -62,9 +62,11 @@ class PythonUDFRunner(
|
||||||
startTime: Long,
|
startTime: Long,
|
||||||
env: SparkEnv,
|
env: SparkEnv,
|
||||||
worker: Socket,
|
worker: Socket,
|
||||||
|
pid: Option[Int],
|
||||||
releasedOrClosed: AtomicBoolean,
|
releasedOrClosed: AtomicBoolean,
|
||||||
context: TaskContext): Iterator[Array[Byte]] = {
|
context: TaskContext): Iterator[Array[Byte]] = {
|
||||||
new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
|
new ReaderIterator(
|
||||||
|
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
|
||||||
|
|
||||||
protected override def read(): Array[Byte] = {
|
protected override def read(): Array[Byte] = {
|
||||||
if (writerThread.exception.isDefined) {
|
if (writerThread.exception.isDefined) {
|
||||||
|
|
Loading…
Reference in a new issue