Merge pull request #434 from pwendell/python-exceptions
SPARK-673: Capture and re-throw Python exceptions
Commit 7e2e046e37
PythonRDD.scala
@@ -103,21 +103,27 @@ private[spark] class PythonRDD[T: ClassManifest](
     private def read(): Array[Byte] = {
       try {
-        val length = stream.readInt()
-        if (length != -1) {
-          val obj = new Array[Byte](length)
-          stream.readFully(obj)
-          obj
-        } else {
-          // We've finished the data section of the output, but we can still read some
-          // accumulator updates; let's do that, breaking when we get EOFException
-          while (true) {
-            val len2 = stream.readInt()
-            val update = new Array[Byte](len2)
-            stream.readFully(update)
-            accumulator += Collections.singletonList(update)
-          }
-          new Array[Byte](0)
+        stream.readInt() match {
+          case length if length > 0 =>
+            val obj = new Array[Byte](length)
+            stream.readFully(obj)
+            obj
+          case -2 =>
+            // Signals that an exception has been thrown in python
+            val exLength = stream.readInt()
+            val obj = new Array[Byte](exLength)
+            stream.readFully(obj)
+            throw new PythonException(new String(obj))
+          case -1 =>
+            // We've finished the data section of the output, but we can still read some
+            // accumulator updates; let's do that, breaking when we get EOFException
+            while (true) {
+              val len2 = stream.readInt()
+              val update = new Array[Byte](len2)
+              stream.readFully(update)
+              accumulator += Collections.singletonList(update)
+            }
+            new Array[Byte](0)
         }
       } catch {
         case eof: EOFException => {
@@ -140,6 +146,9 @@ private[spark] class PythonRDD[T: ClassManifest](
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
  * This is used by PySpark's shuffle operations.
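Taken together, the two hunks above change the framing protocol that PythonRDD.read() expects on the worker's output stream: a positive length prefixes a data record, -2 announces a length-prefixed Python traceback that is re-thrown on the JVM side as PythonException, and -1 marks the end of the data section before accumulator updates. Below is a minimal sketch of that framing against plain Python streams; write_int mirrors the helper used in worker.py, while frame_record, frame_exception, and frame_end_of_data are hypothetical names used only for illustration.

    import io
    import struct

    def write_int(value, stream):
        # 4-byte big-endian signed int, matching Java's DataInputStream.readInt()
        stream.write(struct.pack("!i", value))

    def frame_record(payload, stream):
        # length > 0: `length` bytes of record data follow
        write_int(len(payload), stream)
        stream.write(payload)

    def frame_exception(traceback_text, stream):
        # -2: an exception occurred; a length-prefixed traceback follows,
        # which PythonRDD.read() turns into `throw new PythonException(...)`
        data = traceback_text.encode("utf-8")
        write_int(-2, stream)
        write_int(len(data), stream)
        stream.write(data)

    def frame_end_of_data(stream):
        # -1: data section is over; accumulator updates follow until EOF
        write_int(-1, stream)

    if __name__ == "__main__":
        buf = io.BytesIO()
        frame_record(b"some pickled record", buf)
        frame_exception("Traceback (most recent call last): ...", buf)
        frame_end_of_data(buf)
        print(len(buf.getvalue()), "bytes framed")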
pyspark/worker.py
@@ -2,6 +2,7 @@
 Worker that receives input from Piped RDD.
 """
 import sys
+import traceback
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
@@ -40,8 +41,13 @@ def main():
     else:
         dumps = dump_pickle
     iterator = read_from_pickle_file(sys.stdin)
-    for obj in func(split_index, iterator):
-        write_with_length(dumps(obj), old_stdout)
+    try:
+        for obj in func(split_index, iterator):
+            write_with_length(dumps(obj), old_stdout)
+    except Exception as e:
+        write_int(-2, old_stdout)
+        write_with_length(traceback.format_exc(), old_stdout)
+        sys.exit(-1)
     # Mark the beginning of the accumulators section of the output
     write_int(-1, old_stdout)
     for aid, accum in _accumulatorRegistry.items():
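With both sides in place, an exception raised inside user code on the worker no longer surfaces as an opaque EOF or stream error; the driver sees a JVM exception whose message carries the Python traceback. A rough usage sketch, assuming a local PySpark build that includes this patch (the SparkContext arguments and the exact wrapper around PythonException may differ):

    from pyspark import SparkContext

    sc = SparkContext("local", "python-exception-demo")
    try:
        # The lambda fails inside the Python worker; worker.py writes -2 plus the
        # formatted traceback, and PythonRDD.read() re-throws it on the JVM side.
        sc.parallelize([1, 2, 3]).map(lambda x: x / 0).collect()
    except Exception as err:
        # The error message should embed the Python traceback, e.g. a ZeroDivisionError.
        print(err)
    finally:
        sc.stop()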