diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ed8dc43b16..ee50a8f836 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -113,7 +113,9 @@ class SparkEnv (
   }

   private[spark]
-  def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+  def createPythonWorker(
+      pythonExec: String,
+      envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
     synchronized {
       val key = (pythonExec, envVars)
       pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 6e2b6add96..db0e1003f2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -21,6 +21,7 @@ import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
 import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.{Files => JavaFiles, Path}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
@@ -65,6 +66,15 @@ private[spark] object PythonEvalType {
   }
 }

+private object BasePythonRunner {
+
+  private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
+
+  private def faultHandlerLogPath(pid: Int): Path = {
+    new File(faultHandlerLogDir, pid.toString).toPath
+  }
+}
+
 /**
  * A helper class to run Python mapPartition/UDFs in Spark.
  *
@@ -83,6 +93,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
+  private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
   protected val simplifiedTraceback: Boolean = false

   // All the Python functions should have the same exec, version and envvars.
@@ -143,7 +154,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
     envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
+    if (faultHandlerEnabled) {
+      envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
+    }
+
+    val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
+      pythonExec, envVars.asScala.toMap)
     // Whether is the worker released into idle pool or closed. When any codes try to release or
     // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
     // sure there is only one winner that is going to release or close the worker.
@@ -180,7 +196,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))

     val stdoutIterator = newReaderIterator(
-      stream, writerThread, startTime, env, worker, releasedOrClosed, context)
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
     new InterruptibleIterator(context, stdoutIterator)
   }

@@ -197,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT]

@@ -468,6 +485,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext)
     extends Iterator[OUT] {
@@ -556,6 +574,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
             logError("This may have been caused by a prior exception:", writerThread.exception.get)
             throw writerThread.exception.get

+          case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
+              JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
+            val path = BasePythonRunner.faultHandlerLogPath(pid.get)
+            val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
+            JavaFiles.deleteIfExists(path)
+            throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)
+
           case eof: EOFException =>
             throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
         }
@@ -654,9 +679,11 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {

       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index df236ba892..7b2c36bb10 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -95,11 +95,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))

-  def create(): Socket = {
+  def create(): (Socket, Option[Int]) = {
     if (useDaemon) {
       self.synchronized {
         if (idleWorkers.nonEmpty) {
-          return idleWorkers.dequeue()
+          val worker = idleWorkers.dequeue()
+          return (worker, daemonWorkers.get(worker))
         }
       }
       createThroughDaemon()
@@ -113,9 +114,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
    * processes itself to avoid the high cost of forking from Java. This currently only works
    * on UNIX-based systems.
   */
-  private def createThroughDaemon(): Socket = {
+  private def createThroughDaemon(): (Socket, Option[Int]) = {

-    def createSocket(): Socket = {
+    def createSocket(): (Socket, Option[Int]) = {
       val socket = new Socket(daemonHost, daemonPort)
       val pid = new DataInputStream(socket.getInputStream).readInt()
       if (pid < 0) {
@@ -124,7 +125,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

       authHelper.authToServer(socket)
       daemonWorkers.put(socket, pid)
-      socket
+      (socket, Some(pid))
     }

     self.synchronized {
@@ -148,7 +149,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   /**
    * Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
    */
-  private def createSimpleWorker(): Socket = {
+  private def createSimpleWorker(): (Socket, Option[Int]) = {
     var serverSocket: ServerSocket = null
     try {
       serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
@@ -173,10 +174,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       try {
         val socket = serverSocket.accept()
         authHelper.authClient(socket)
+        // TODO: When we drop JDK 8, we can just use worker.pid()
+        val pid = new DataInputStream(socket.getInputStream).readInt()
+        if (pid < 0) {
+          throw new IllegalStateException("Python failed to launch worker with code " + pid)
+        }
         self.synchronized {
           simpleWorkers.put(socket, worker)
         }
-        return socket
+        return (socket, Some(pid))
       } catch {
         case e: Exception =>
           throw new SparkException("Python worker failed to connect back.", e)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 348a33e129..5e026fdd3b 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -56,4 +56,12 @@ private[spark] object Python {
     .version("3.1.0")
     .timeConf(TimeUnit.SECONDS)
     .createWithDefaultString("15s")
+
+  val PYTHON_WORKER_FAULTHANLDER_ENABLED = ConfigBuilder("spark.python.worker.faulthandler.enabled")
+    .doc("When true, Python workers set up the faulthandler so that, when a Python worker " +
+      "exits unexpectedly (crashes), the stack trace at the moment of the crash is included " +
+      "in the error message if it can be captured successfully.")
+    .version("3.2.0")
+    .booleanConf
+    .createWithDefault(false)
 }
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index 120c5e36fe..a77d38e286 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -206,6 +206,35 @@ class WorkerMemoryTest(unittest.TestCase):
     def tearDown(self):
         self.sc.stop()

+
+class WorkerSegfaultTest(ReusedPySparkTestCase):
+
+    @classmethod
+    def conf(cls):
+        _conf = super(WorkerSegfaultTest, cls).conf()
+        _conf.set("spark.python.worker.faulthandler.enabled", "true")
+        return _conf
+
+    def test_python_segfault(self):
+        try:
+            def f():
+                import ctypes
+                ctypes.string_at(0)
+
+            self.sc.parallelize([1]).map(lambda x: f()).count()
+        except Py4JJavaError as e:
+            self.assertRegex(str(e), "Segmentation fault")
+
+
+class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest):
+
+    @classmethod
+    def conf(cls):
+        _conf = super(WorkerSegfaultNonDaemonTest, cls).conf()
+        _conf.set("spark.python.use.daemon", "false")
+        return _conf
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_worker import *  # noqa: F401
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 023a6553c8..a13717f473 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,6 +31,7 @@ except ImportError:
     has_resource_module = False
 import traceback
 import warnings
+import faulthandler

 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -463,7 +464,13 @@ def read_udfs(pickleSer, infile, eval_type):


 def main(infile, outfile):
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
         boot_time = time.time()
         split_index = read_int(infile)
         if split_index == -1:  # for unit tests
@@ -636,6 +643,11 @@ def main(infile, outfile):
         print("PySpark worker failed with exception:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
         sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
     write_long(shuffle.MemoryBytesSpilled, outfile)
@@ -661,4 +673,7 @@ if __name__ == '__main__':
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
     (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
+    write_int(os.getpid(), sock_file)
+    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index bb35306238..00bab1e9fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -43,10 +43,12 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[ColumnarBatch] = {

-    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {

       private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         s"stdin reader for $pythonExec", 0, Long.MaxValue)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index d9fe07214d..d1109d251c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -62,9 +62,11 @@ class PythonUDFRunner(
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {

       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {
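
Usage note (not part of the patch): the standalone sketch below mirrors WorkerSegfaultTest above and shows how the new spark.python.worker.faulthandler.enabled flag is expected to behave end to end; the local master and application name are arbitrary placeholders.

# Minimal sketch: with the flag enabled, a crashing Python worker's
# faulthandler output (e.g. "Segmentation fault" plus the Python stack
# trace at the moment of the crash) is surfaced in the driver-side error.
from pyspark import SparkConf, SparkContext

conf = SparkConf().set("spark.python.worker.faulthandler.enabled", "true")
sc = SparkContext(master="local[1]", appName="faulthandler-demo", conf=conf)

def crash(_):
    import ctypes
    ctypes.string_at(0)  # dereference a null pointer to force a segfault

try:
    sc.parallelize([1]).map(crash).count()
except Exception as e:
    # The error message should now include the captured faulthandler log
    # instead of only "Python worker exited unexpectedly (crashed)".
    print(str(e))
finally:
    sc.stop()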