[SPARK-20148][SQL] Extend the file commit API to allow subscribing to task commit messages
## What changes were proposed in this pull request? The internal FileCommitProtocol interface returns all task commit messages in bulk to the implementation when a job finishes. However, it is sometimes useful to access those messages before the job completes, so that the driver gets incremental progress updates before the job finishes. This adds an `onTaskCommit` listener to the internal api. ## How was this patch tested? Unit tests. cc rxin Author: Eric Liang <ekl@databricks.com> Closes #17475 from ericl/file-commit-api-ext.
This commit is contained in:
parent
60977889ea
commit
79636054f6
|
@ -121,6 +121,13 @@ abstract class FileCommitProtocol {
|
|||
def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = {
|
||||
fs.delete(path, recursive)
|
||||
}
|
||||
|
||||
/**
|
||||
* Called on the driver after a task commits. This can be used to access task commit messages
|
||||
* before the job has finished. These same task commit messages will be passed to commitJob()
|
||||
* if the entire job succeeds.
|
||||
*/
|
||||
def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -80,6 +80,9 @@ object FileFormatWriter extends Logging {
|
|||
""".stripMargin)
|
||||
}
|
||||
|
||||
/** The result of a successful write task. */
|
||||
private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String])
|
||||
|
||||
/**
|
||||
* Basic work flow of this command is:
|
||||
* 1. Driver side setup, including output committer initialization and data source specific
|
||||
|
@ -172,8 +175,9 @@ object FileFormatWriter extends Logging {
|
|||
global = false,
|
||||
child = queryExecution.executedPlan).execute()
|
||||
}
|
||||
|
||||
val ret = sparkSession.sparkContext.runJob(rdd,
|
||||
val ret = new Array[WriteTaskResult](rdd.partitions.length)
|
||||
sparkSession.sparkContext.runJob(
|
||||
rdd,
|
||||
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
|
||||
executeTask(
|
||||
description = description,
|
||||
|
@ -182,10 +186,16 @@ object FileFormatWriter extends Logging {
|
|||
sparkAttemptNumber = taskContext.attemptNumber(),
|
||||
committer,
|
||||
iterator = iter)
|
||||
},
|
||||
0 until rdd.partitions.length,
|
||||
(index, res: WriteTaskResult) => {
|
||||
committer.onTaskCommit(res.commitMsg)
|
||||
ret(index) = res
|
||||
})
|
||||
|
||||
val commitMsgs = ret.map(_._1)
|
||||
val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment)
|
||||
val commitMsgs = ret.map(_.commitMsg)
|
||||
val updatedPartitions = ret.flatMap(_.updatedPartitions)
|
||||
.distinct.map(PartitioningUtils.parsePathFragment)
|
||||
|
||||
committer.commitJob(job, commitMsgs)
|
||||
logInfo(s"Job ${job.getJobID} committed.")
|
||||
|
@ -205,7 +215,7 @@ object FileFormatWriter extends Logging {
|
|||
sparkPartitionId: Int,
|
||||
sparkAttemptNumber: Int,
|
||||
committer: FileCommitProtocol,
|
||||
iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = {
|
||||
iterator: Iterator[InternalRow]): WriteTaskResult = {
|
||||
|
||||
val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
|
||||
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
|
||||
|
@ -238,7 +248,7 @@ object FileFormatWriter extends Logging {
|
|||
// Execute the task to write rows out and commit the task.
|
||||
val outputPartitions = writeTask.execute(iterator)
|
||||
writeTask.releaseResources()
|
||||
(committer.commitTask(taskAttemptContext), outputPartitions)
|
||||
WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions)
|
||||
})(catchBlock = {
|
||||
// If there is an error, release resource and then abort the task
|
||||
try {
|
||||
|
|
|
@ -18,9 +18,12 @@
|
|||
package org.apache.spark.sql.test
|
||||
|
||||
import java.io.File
|
||||
import java.util.concurrent.ConcurrentLinkedQueue
|
||||
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
|
||||
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.sources._
|
||||
|
@ -41,7 +44,6 @@ object LastOptions {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/** Dummy provider. */
|
||||
class DefaultSource
|
||||
extends RelationProvider
|
||||
|
@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema
|
|||
}
|
||||
}
|
||||
|
||||
object MessageCapturingCommitProtocol {
|
||||
val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]()
|
||||
}
|
||||
|
||||
class MessageCapturingCommitProtocol(jobId: String, path: String)
|
||||
extends HadoopMapReduceCommitProtocol(jobId, path) {
|
||||
|
||||
// captures commit messages for testing
|
||||
override def onTaskCommit(msg: TaskCommitMessage): Unit = {
|
||||
MessageCapturingCommitProtocol.commitMessages.offer(msg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
|
||||
import testImplicits._
|
||||
|
||||
|
@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
|
|||
Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
|
||||
}
|
||||
|
||||
test("write path implements onTaskCommit API correctly") {
|
||||
withSQLConf(
|
||||
"spark.sql.sources.commitProtocolClass" ->
|
||||
classOf[MessageCapturingCommitProtocol].getCanonicalName) {
|
||||
withTempDir { dir =>
|
||||
val path = dir.getCanonicalPath
|
||||
MessageCapturingCommitProtocol.commitMessages.clear()
|
||||
spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
|
||||
assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("read a data source that does not extend SchemaRelationProvider") {
|
||||
val dfReader = spark.read
|
||||
.option("from", "1")
|
||||
|
|
Loading…
Reference in a new issue