[SPARK-20148][SQL] Extend the file commit API to allow subscribing to task commit messages

## What changes were proposed in this pull request?

The internal `FileCommitProtocol` interface currently hands all task commit messages to the implementation in bulk (via `commitJob()`) only once the job finishes. However, it is sometimes useful to access those messages earlier, so that the driver can report incremental progress while the job is still running.

This patch adds an `onTaskCommit` listener to the internal API, called on the driver each time a task commits.
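For illustration, here is a minimal sketch of what an implementation of the hook could look like, modeled on the message-capturing protocol added to the test suite in this patch. The `ProgressReportingCommitProtocol` name and the progress counter are hypothetical, not part of the change:

```scala
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

object ProgressReportingCommitProtocol {
  // Driver-side counter; onTaskCommit is only invoked on the driver, so a shared
  // atomic counter in the companion object is enough for this sketch.
  val committedTasks = new AtomicInteger(0)
}

// Hypothetical protocol that observes commit progress incrementally instead of
// waiting for commitJob() to receive all task commit messages in bulk.
class ProgressReportingCommitProtocol(jobId: String, path: String)
  extends HadoopMapReduceCommitProtocol(jobId, path) {

  // Called on the driver after each task commits; the same messages are still
  // passed to commitJob() if the whole job succeeds.
  override def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {
    val n = ProgressReportingCommitProtocol.committedTasks.incrementAndGet()
    println(s"tasks committed so far: $n")
  }
}
```

As the new unit test does, such a protocol can be plugged into the write path through the `spark.sql.sources.commitProtocolClass` SQL conf.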

## How was this patch tested?

Unit tests. A new test in `DataFrameReaderWriterSuite` writes a 10-partition DataFrame through a message-capturing commit protocol and checks that `onTaskCommit` is invoked once per task.

cc rxin

Author: Eric Liang <ekl@databricks.com>

Closes #17475 from ericl/file-commit-api-ext.
Commit 79636054f6 (parent 60977889ea), authored by Eric Liang on 2017-03-29 20:59:48 -07:00, committed by Reynold Xin.
3 changed files with 53 additions and 7 deletions.

FileCommitProtocol.scala

@@ -121,6 +121,13 @@ abstract class FileCommitProtocol {
def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = {
fs.delete(path, recursive)
}
+ /**
+ * Called on the driver after a task commits. This can be used to access task commit messages
+ * before the job has finished. These same task commit messages will be passed to commitJob()
+ * if the entire job succeeds.
+ */
+ def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {}
}

FileFormatWriter.scala

@@ -80,6 +80,9 @@ object FileFormatWriter extends Logging {
""".stripMargin)
}
+ /** The result of a successful write task. */
+ private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String])
/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -172,8 +175,9 @@ object FileFormatWriter extends Logging {
global = false,
child = queryExecution.executedPlan).execute()
}
- val ret = sparkSession.sparkContext.runJob(rdd,
+ val ret = new Array[WriteTaskResult](rdd.partitions.length)
+ sparkSession.sparkContext.runJob(
+ rdd,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
@@ -182,10 +186,16 @@ object FileFormatWriter extends Logging {
sparkAttemptNumber = taskContext.attemptNumber(),
committer,
iterator = iter)
+ },
+ 0 until rdd.partitions.length,
+ (index, res: WriteTaskResult) => {
+ committer.onTaskCommit(res.commitMsg)
+ ret(index) = res
})
- val commitMsgs = ret.map(_._1)
- val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment)
+ val commitMsgs = ret.map(_.commitMsg)
+ val updatedPartitions = ret.flatMap(_.updatedPartitions)
+ .distinct.map(PartitioningUtils.parsePathFragment)
committer.commitJob(job, commitMsgs)
logInfo(s"Job ${job.getJobID} committed.")
@@ -205,7 +215,7 @@ object FileFormatWriter extends Logging {
sparkPartitionId: Int,
sparkAttemptNumber: Int,
committer: FileCommitProtocol,
- iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = {
+ iterator: Iterator[InternalRow]): WriteTaskResult = {
val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -238,7 +248,7 @@ object FileFormatWriter extends Logging {
// Execute the task to write rows out and commit the task.
val outputPartitions = writeTask.execute(iterator)
writeTask.releaseResources()
- (committer.commitTask(taskAttemptContext), outputPartitions)
+ WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions)
})(catchBlock = {
// If there is an error, release resource and then abort the task
try {

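The change above hinges on the `SparkContext.runJob` overload that takes an explicit result handler: instead of collecting all task results into a returned array, the handler is invoked on the driver for each finished task, which is what allows `committer.onTaskCommit` to fire before `commitJob`. Below is a minimal, self-contained sketch of that pattern; the example job, names, and per-partition counts are illustrative, not from this patch:

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object ResultHandlerExample {
  def main(args: Array[String]): Unit = {
    // Local-mode demo context.
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("result-handler-demo"))

    val rdd = sc.parallelize(1 to 100, numSlices = 4)
    val counts = new Array[Long](rdd.partitions.length)

    // The result handler runs on the driver as each task finishes, so progress is
    // visible incrementally -- the same mechanism FileFormatWriter now uses to call
    // committer.onTaskCommit(...) per task before committer.commitJob(...).
    sc.runJob(
      rdd,
      (_: TaskContext, iter: Iterator[Int]) => iter.size.toLong,
      0 until rdd.partitions.length,
      (index: Int, count: Long) => {
        counts(index) = count
        println(s"partition $index finished with $count records")
      })

    println(s"total records: ${counts.sum}")
    sc.stop()
  }
}
```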
DataFrameReaderWriterSuite.scala

@@ -18,9 +18,12 @@
package org.apache.spark.sql.test
import java.io.File
+ import java.util.concurrent.ConcurrentLinkedQueue
import org.scalatest.BeforeAndAfter
+ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources._
@@ -41,7 +44,6 @@ object LastOptions {
}
}
/** Dummy provider. */
class DefaultSource
extends RelationProvider
@@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema
}
}
+ object MessageCapturingCommitProtocol {
+ val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]()
+ }
+ class MessageCapturingCommitProtocol(jobId: String, path: String)
+ extends HadoopMapReduceCommitProtocol(jobId, path) {
+ // captures commit messages for testing
+ override def onTaskCommit(msg: TaskCommitMessage): Unit = {
+ MessageCapturingCommitProtocol.commitMessages.offer(msg)
+ }
+ }
class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
@@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
}
test("write path implements onTaskCommit API correctly") {
withSQLConf(
"spark.sql.sources.commitProtocolClass" ->
classOf[MessageCapturingCommitProtocol].getCanonicalName) {
withTempDir { dir =>
val path = dir.getCanonicalPath
MessageCapturingCommitProtocol.commitMessages.clear()
spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)
}
}
}
test("read a data source that does not extend SchemaRelationProvider") {
val dfReader = spark.read
.option("from", "1")