diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 296179b75b..085829af6e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1111,9 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
           maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
           recordsWritten += 1
         }
-      } {
-        writer.close(hadoopContext)
-      }
+      }(finallyBlock = writer.close(hadoopContext))
       committer.commitTask(hadoopContext)
       outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
         om.setBytesWritten(callback())
@@ -1200,9 +1198,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
           maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
           recordsWritten += 1
         }
-      } {
-        writer.close()
-      }
+      }(finallyBlock = writer.close())
       writer.commit()
       outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
         om.setBytesWritten(callback())
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 46c64f61de..c91d8fbfc4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -80,10 +80,16 @@ private[spark] abstract class Task[T](
     }
     try {
       runTask(context)
-    } catch { case e: Throwable =>
-      // Catch all errors; run task failure callbacks, and rethrow the exception.
-      context.markTaskFailed(e)
-      throw e
+    } catch {
+      case e: Throwable =>
+        // Catch all errors; run task failure callbacks, and rethrow the exception.
+        try {
+          context.markTaskFailed(e)
+        } catch {
+          case t: Throwable =>
+            e.addSuppressed(t)
+        }
+        throw e
     } finally {
       // Call the task completion callbacks.
       context.markTaskCompleted()
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c304629bcd..78e164cff7 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1260,26 +1260,35 @@ private[spark] object Utils extends Logging {
   }
 
   /**
-   * Execute a block of code, call the failure callbacks before finally block if there is any
-   * exceptions happen. But if exceptions happen in the finally block, do not suppress the original
-   * exception.
+   * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur
+   * in either the catch or the finally block, they are appended to the list of suppressed
+   * exceptions in the original exception, which is then rethrown.
    *
-   * This is primarily an issue with `finally { out.close() }` blocks, where
-   * close needs to be called to clean up `out`, but if an exception happened
-   * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+   * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks,
+   * where the abort/close needs to be called to clean up `out`, but if an exception happened
+   * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will
    * fail as well. This would then suppress the original/likely more meaningful
    * exception from the original `out.write` call.
   */
-  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)
+      (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
     var originalThrowable: Throwable = null
     try {
       block
     } catch {
-      case t: Throwable =>
+      case cause: Throwable =>
         // Purposefully not using NonFatal, because even fatal exceptions
         // we don't want to have our finallyBlock suppress
-        originalThrowable = t
-        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+        originalThrowable = cause
+        try {
+          logError("Aborting task", originalThrowable)
+          TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable)
+          catchBlock
+        } catch {
+          case t: Throwable =>
+            originalThrowable.addSuppressed(t)
+            logWarning(s"Suppressing exception in catch: " + t.getMessage, t)
+        }
         throw originalThrowable
     } finally {
       try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index d2bbf196cb..b9a3162aba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /** A container for all the details required when writing to a table. */
 case class WriteRelation(
@@ -247,19 +247,16 @@ private[sql] class DefaultWriterContainer(
 
     // If anything below fails, we should abort the task.
     try {
-      while (iterator.hasNext) {
-        val internalRow = iterator.next()
-        writer.writeInternal(internalRow)
-      }
-
-      commitTask()
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        while (iterator.hasNext) {
+          val internalRow = iterator.next()
+          writer.writeInternal(internalRow)
+        }
+        commitTask()
+      }(catchBlock = abortTask())
     } catch {
-      case cause: Throwable =>
-        logError("Aborting task.", cause)
-        // call failure callbacks first, so we could have a chance to cleanup the writer.
-        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
-        abortTask()
-        throw new SparkException("Task failed while writing rows.", cause)
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
     }
 
     def commitTask(): Unit = {
@@ -413,37 +410,37 @@ private[sql] class DynamicPartitionWriterContainer(
 
     // If anything below fails, we should abort the task.
     var currentWriter: OutputWriter = null
     try {
-      var currentKey: UnsafeRow = null
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        var currentKey: UnsafeRow = null
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+              currentWriter = null
+            }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
+            currentWriter = newOutputWriter(currentKey, getPartitionString)
           }
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
-
-          currentWriter = newOutputWriter(currentKey, getPartitionString)
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+        if (currentWriter != null) {
+          currentWriter.close()
+          currentWriter = null
         }
-        currentWriter.writeInternal(sortedIterator.getValue)
-      }
-      if (currentWriter != null) {
-        currentWriter.close()
-        currentWriter = null
-      }
-      commitTask()
-    } catch {
-      case cause: Throwable =>
-        logError("Aborting task.", cause)
-        // call failure callbacks first, so we could have a chance to cleanup the writer.
-        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+        commitTask()
+      }(catchBlock = {
         if (currentWriter != null) {
           currentWriter.close()
         }
         abortTask()
-        throw new SparkException("Task failed while writing rows.", cause)
+      })
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
     }
   }
 }
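
For readers unfamiliar with the pattern above, here is a minimal, self-contained Scala sketch of the same idea. It is not Spark's actual Utils implementation; SafeTryExample and tryWithCallbacks are hypothetical names used only for illustration. It shows how a failure in the catch or finally callback is attached to the original exception via addSuppressed instead of replacing it, which is what the changes to Utils.tryWithSafeFinallyAndFailureCallbacks and Task.run accomplish.

// Standalone sketch of the pattern introduced above (hypothetical names; this is
// not Spark's actual Utils). Failures in the catch/finally callbacks are attached
// to the original exception via addSuppressed instead of replacing it.
object SafeTryExample {

  def tryWithCallbacks[T](block: => T)
      (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
    var original: Throwable = null
    try {
      block
    } catch {
      case cause: Throwable =>
        // Deliberately not NonFatal: even fatal errors must not be masked by cleanup.
        original = cause
        try {
          catchBlock                       // e.g. abortTask()
        } catch {
          case t: Throwable => original.addSuppressed(t)
        }
        throw original
    } finally {
      try {
        finallyBlock                       // e.g. writer.close()
      } catch {
        // If the block already failed, keep that failure primary; if the guard
        // does not match (block succeeded), the cleanup failure propagates.
        case t: Throwable if original != null => original.addSuppressed(t)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    try {
      tryWithCallbacks[Unit] {
        throw new RuntimeException("write failed")              // the "real" error
      }(catchBlock = throw new IllegalStateException("abort failed"),
        finallyBlock = throw new IllegalStateException("close failed"))
    } catch {
      case e: Throwable =>
        println(e)                                              // the original failure
        e.getSuppressed.foreach(s => println("  suppressed: " + s))
    }
  }
}

Running main reports the original "write failed" RuntimeException, with the two cleanup failures preserved as suppressed exceptions rather than hiding it.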