[SPARK-24330][SQL] Refactor ExecuteWriteTask and Use while in writing files

## What changes were proposed in this pull request?

1. Refactor ExecuteWriteTask in FileFormatWriter to reduce common logic and improve readability. After the change, callers only need to call `commit()` or `abort()` at the end of the task. There is also less code in `SingleDirectoryWriteTask` and `DynamicPartitionWriteTask`. Definitions of the related classes are moved to a new file, and `ExecuteWriteTask` is renamed to `FileFormatDataWriter`.
2. As per the code style guide https://github.com/databricks/scala-style-guide#traversal-and-zipwithindex, we avoid using `for` for looping in [FileFormatWriter](https://github.com/apache/spark/pull/21381/files#diff-3b69eb0963b68c65cfe8075f8a42e850L536), or `foreach` in [WriteToDataSourceV2Exec](https://github.com/apache/spark/pull/21381/files#diff-6fbe10db766049a395bae2e785e9d56eL119). In such a critical code path, using `while` is better for performance.

## How was this patch tested?

Existing unit tests. I also tried the microbenchmark in https://github.com/apache/spark/pull/21409:

| Workload | Before changes (Best/Avg Time (ms)) | After changes (Best/Avg Time (ms)) |
| --- | --- | --- |
| Output Single Int Column | 2018 / 2043 | 2096 / 2236 |
| Output Single Double Column | 1978 / 2043 | 2013 / 2018 |
| Output Int and String Column | 6332 / 6706 | 6162 / 6298 |
| Output Partitions | 4458 / 5094 | 3792 / 4008 |
| Output Buckets | 5695 / 6102 | 5120 / 5154 |

Also a microbenchmark on my laptop for a general comparison among `while`/`foreach`/`for`:

```scala
class Writer {
  var sum = 0L
  def write(l: Long): Unit = sum += l
}

def testWhile(iterator: Iterator[Long]): Long = {
  val w = new Writer
  while (iterator.hasNext) {
    w.write(iterator.next())
  }
  w.sum
}

def testForeach(iterator: Iterator[Long]): Long = {
  val w = new Writer
  iterator.foreach(w.write)
  w.sum
}

def testFor(iterator: Iterator[Long]): Long = {
  val w = new Writer
  for (x <- iterator) {
    w.write(x)
  }
  w.sum
}

val data = 0L to 100000000L

val start = System.nanoTime
(0 to 10).foreach(_ => testWhile(data.iterator))
println("benchmark while: " + (System.nanoTime - start) / 1000000)

val start2 = System.nanoTime
(0 to 10).foreach(_ => testForeach(data.iterator))
println("benchmark foreach: " + (System.nanoTime - start2) / 1000000)

val start3 = System.nanoTime
(0 to 10).foreach(_ => testFor(data.iterator))
println("benchmark for: " + (System.nanoTime - start3) / 1000000)
```

Benchmark result:

- `while`: 15401 ms
- `foreach`: 43034 ms
- `for`: 41279 ms

Author: Gengliang Wang <gengliang.wang@databricks.com>

Closes #21381 from gengliangwang/refactorExecuteWriteTask.
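For illustration, here is a minimal sketch of the calling pattern the refactored API enables. The `runTask` helper below is hypothetical and not part of the patch; the real caller lives in `FileFormatWriter` and wraps the same logic in `Utils.tryWithSafeFinallyAndFailureCallbacks`, as shown in the diff below.

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical helper (not part of this patch) showing how a task drives the
// new FileFormatDataWriter: loop with write(), then commit() or abort().
def runTask(
    dataWriter: FileFormatDataWriter,
    iterator: Iterator[InternalRow]): WriteTaskResult = {
  try {
    // Hot path: use `while` rather than `foreach`/`for` over the iterator.
    while (iterator.hasNext) {
      dataWriter.write(iterator.next())
    }
    // commit() releases resources, commits the Hadoop task and returns the summary.
    dataWriter.commit()
  } catch {
    case t: Throwable =>
      // abort() releases resources and then aborts the Hadoop task.
      dataWriter.abort()
      throw t
  }
}
```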
Parent commit: 2c9c8629b7
This commit: cbaa729132
```diff
@@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration

 /**
- * Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]].
+ * Simple metrics collected during an instance of [[FileFormatDataWriter]].
  * These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703).
  */
 case class BasicWriteTaskStats(
```
@@ -0,0 +1,313 @@ (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.datasources

import scala.collection.mutable

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.SerializableConfiguration

/**
 * Abstract class for writing out data in a single Spark task.
 * Exceptions thrown by the implementation of this trait will automatically trigger task aborts.
 */
abstract class FileFormatDataWriter(
    description: WriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol) {
  /**
   * Max number of files a single task writes out due to file size. In most cases the number of
   * files written should be very small. This is just a safe guard to protect some really bad
   * settings, e.g. maxRecordsPerFile = 1.
   */
  protected val MAX_FILE_COUNTER: Int = 1000 * 1000
  protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]()
  protected var currentWriter: OutputWriter = _

  /** Trackers for computing various statistics on the data as it's being written out. */
  protected val statsTrackers: Seq[WriteTaskStatsTracker] =
    description.statsTrackers.map(_.newTaskInstance())

  protected def releaseResources(): Unit = {
    if (currentWriter != null) {
      try {
        currentWriter.close()
      } finally {
        currentWriter = null
      }
    }
  }

  /** Writes a record */
  def write(record: InternalRow): Unit

  /**
   * Returns the summary of relative information which
   * includes the list of partition strings written out. The list of partitions is sent back
   * to the driver and used to update the catalog. Other information will be sent back to the
   * driver too and used to e.g. update the metrics in UI.
   */
  def commit(): WriteTaskResult = {
    releaseResources()
    val summary = ExecutedWriteSummary(
      updatedPartitions = updatedPartitions.toSet,
      stats = statsTrackers.map(_.getFinalStats()))
    WriteTaskResult(committer.commitTask(taskAttemptContext), summary)
  }

  def abort(): Unit = {
    try {
      releaseResources()
    } finally {
      committer.abortTask(taskAttemptContext)
    }
  }
}

/** FileFormatWriteTask for empty partitions */
class EmptyDirectoryDataWriter(
    description: WriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol
  ) extends FileFormatDataWriter(description, taskAttemptContext, committer) {
  override def write(record: InternalRow): Unit = {}
}

/** Writes data to a single directory (used for non-dynamic-partition writes). */
class SingleDirectoryDataWriter(
    description: WriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol)
  extends FileFormatDataWriter(description, taskAttemptContext, committer) {
  private var fileCounter: Int = _
  private var recordsInFile: Long = _
  // Initialize currentWriter and statsTrackers
  newOutputWriter()

  private def newOutputWriter(): Unit = {
    recordsInFile = 0
    releaseResources()

    val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
    val currentPath = committer.newTaskTempFile(
      taskAttemptContext,
      None,
      f"-c$fileCounter%03d" + ext)

    currentWriter = description.outputWriterFactory.newInstance(
      path = currentPath,
      dataSchema = description.dataColumns.toStructType,
      context = taskAttemptContext)

    statsTrackers.foreach(_.newFile(currentPath))
  }

  override def write(record: InternalRow): Unit = {
    if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
      fileCounter += 1
      assert(fileCounter < MAX_FILE_COUNTER,
        s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

      newOutputWriter()
    }

    currentWriter.write(record)
    statsTrackers.foreach(_.newRow(record))
    recordsInFile += 1
  }
}

/**
 * Writes data to using dynamic partition writes, meaning this single function can write to
 * multiple directories (partitions) or files (bucketing).
 */
class DynamicPartitionDataWriter(
    description: WriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol)
  extends FileFormatDataWriter(description, taskAttemptContext, committer) {

  /** Flag saying whether or not the data to be written out is partitioned. */
  private val isPartitioned = description.partitionColumns.nonEmpty

  /** Flag saying whether or not the data to be written out is bucketed. */
  private val isBucketed = description.bucketIdExpression.isDefined

  assert(isPartitioned || isBucketed,
    s"""DynamicPartitionWriteTask should be used for writing out data that's either
       |partitioned or bucketed. In this case neither is true.
       |WriteJobDescription: $description
     """.stripMargin)

  private var fileCounter: Int = _
  private var recordsInFile: Long = _
  private var currentPartionValues: Option[UnsafeRow] = None
  private var currentBucketId: Option[Int] = None

  /** Extracts the partition values out of an input row. */
  private lazy val getPartitionValues: InternalRow => UnsafeRow = {
    val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns)
    row => proj(row)
  }

  /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */
  private lazy val partitionPathExpression: Expression = Concat(
    description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
      val partitionName = ScalaUDF(
        ExternalCatalogUtils.getPartitionPathString _,
        StringType,
        Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))))
      if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
    })

  /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
   * the partition string. */
  private lazy val getPartitionPath: InternalRow => String = {
    val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns)
    row => proj(row).getString(0)
  }

  /** Given an input row, returns the corresponding `bucketId` */
  private lazy val getBucketId: InternalRow => Int = {
    val proj =
      UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
    row => proj(row).getInt(0)
  }

  /** Returns the data columns to be written given an input row */
  private val getOutputRow =
    UnsafeProjection.create(description.dataColumns, description.allColumns)

  /**
   * Opens a new OutputWriter given a partition key and/or a bucket id.
   * If bucket id is specified, we will append it to the end of the file name, but before the
   * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
   *
   * @param partitionValues the partition which all tuples being written by this `OutputWriter`
   *                        belong to
   * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
   */
  private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = {
    recordsInFile = 0
    releaseResources()

    val partDir = partitionValues.map(getPartitionPath(_))
    partDir.foreach(updatedPartitions.add)

    val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")

    // This must be in a form that matches our bucketing format. See BucketingUtils.
    val ext = f"$bucketIdStr.c$fileCounter%03d" +
      description.outputWriterFactory.getFileExtension(taskAttemptContext)

    val customPath = partDir.flatMap { dir =>
      description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
    }
    val currentPath = if (customPath.isDefined) {
      committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
    } else {
      committer.newTaskTempFile(taskAttemptContext, partDir, ext)
    }

    currentWriter = description.outputWriterFactory.newInstance(
      path = currentPath,
      dataSchema = description.dataColumns.toStructType,
      context = taskAttemptContext)

    statsTrackers.foreach(_.newFile(currentPath))
  }

  override def write(record: InternalRow): Unit = {
    val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
    val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None

    if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
      // See a new partition or bucket - write to a new partition dir (or a new bucket file).
      if (isPartitioned && currentPartionValues != nextPartitionValues) {
        currentPartionValues = Some(nextPartitionValues.get.copy())
        statsTrackers.foreach(_.newPartition(currentPartionValues.get))
      }
      if (isBucketed) {
        currentBucketId = nextBucketId
        statsTrackers.foreach(_.newBucket(currentBucketId.get))
      }

      fileCounter = 0
      newOutputWriter(currentPartionValues, currentBucketId)
    } else if (description.maxRecordsPerFile > 0 &&
        recordsInFile >= description.maxRecordsPerFile) {
      // Exceeded the threshold in terms of the number of records per file.
      // Create a new file by increasing the file counter.
      fileCounter += 1
      assert(fileCounter < MAX_FILE_COUNTER,
        s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

      newOutputWriter(currentPartionValues, currentBucketId)
    }
    val outputRow = getOutputRow(record)
    currentWriter.write(outputRow)
    statsTrackers.foreach(_.newRow(outputRow))
    recordsInFile += 1
  }
}

/** A shared job description for all the write tasks. */
class WriteJobDescription(
    val uuid: String, // prevent collision between different (appending) write jobs
    val serializableHadoopConf: SerializableConfiguration,
    val outputWriterFactory: OutputWriterFactory,
    val allColumns: Seq[Attribute],
    val dataColumns: Seq[Attribute],
    val partitionColumns: Seq[Attribute],
    val bucketIdExpression: Option[Expression],
    val path: String,
    val customPartitionLocations: Map[TablePartitionSpec, String],
    val maxRecordsPerFile: Long,
    val timeZoneId: String,
    val statsTrackers: Seq[WriteJobStatsTracker])
  extends Serializable {

  assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
    s"""
       |All columns: ${allColumns.mkString(", ")}
       |Partition columns: ${partitionColumns.mkString(", ")}
       |Data columns: ${dataColumns.mkString(", ")}
     """.stripMargin)
}

/** The result of a successful write task. */
case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)

/**
 * Wrapper class for the metrics of writing data out.
 *
 * @param updatedPartitions the partitions updated during writing data out. Only valid
 *                          for dynamic partition.
 * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had.
 */
case class ExecutedWriteSummary(
    updatedPartitions: Set[String],
    stats: Seq[WriteTaskStats])
```
```diff
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources

 import java.util.{Date, UUID}

-import scala.collection.mutable
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
```
```diff
@@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
-import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
-import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.{SerializableConfiguration, Utils}


 /** A helper object for writing FileFormat data out to a location. */
 object FileFormatWriter extends Logging {

-  /**
-   * Max number of files a single task writes out due to file size. In most cases the number of
-   * files written should be very small. This is just a safe guard to protect some really bad
-   * settings, e.g. maxRecordsPerFile = 1.
-   */
-  private val MAX_FILE_COUNTER = 1000 * 1000
-
   /** Describes how output files should be placed in the filesystem. */
   case class OutputSpec(
-    outputPath: String,
-    customPartitionLocations: Map[TablePartitionSpec, String],
-    outputColumns: Seq[Attribute])
-
-  /** A shared job description for all the write tasks. */
-  private class WriteJobDescription(
-      val uuid: String, // prevent collision between different (appending) write jobs
-      val serializableHadoopConf: SerializableConfiguration,
-      val outputWriterFactory: OutputWriterFactory,
-      val allColumns: Seq[Attribute],
-      val dataColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
-      val bucketIdExpression: Option[Expression],
-      val path: String,
-      val customPartitionLocations: Map[TablePartitionSpec, String],
-      val maxRecordsPerFile: Long,
-      val timeZoneId: String,
-      val statsTrackers: Seq[WriteJobStatsTracker])
-    extends Serializable {
-
-    assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
-      s"""
-         |All columns: ${allColumns.mkString(", ")}
-         |Partition columns: ${partitionColumns.mkString(", ")}
-         |Data columns: ${dataColumns.mkString(", ")}
-       """.stripMargin)
-  }
-
-  /** The result of a successful write task. */
-  private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
+      outputPath: String,
+      customPartitionLocations: Map[TablePartitionSpec, String],
+      outputColumns: Seq[Attribute])

   /**
    * Basic work flow of this command is:
```
```diff
@@ -262,30 +223,27 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)

-    val writeTask =
+    val dataWriter =
       if (sparkPartitionId != 0 && !iterator.hasNext) {
         // In case of empty job, leave first partition to save meta for file format like parquet.
-        new EmptyDirectoryWriteTask(description)
+        new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
       } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
-        new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
+        new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
       } else {
-        new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
+        new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
       }

     try {
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-        // Execute the task to write rows out and commit the task.
-        val summary = writeTask.execute(iterator)
-        writeTask.releaseResources()
-        WriteTaskResult(committer.commitTask(taskAttemptContext), summary)
-      })(catchBlock = {
-        // If there is an error, release resource and then abort the task
-        try {
-          writeTask.releaseResources()
-        } finally {
-          committer.abortTask(taskAttemptContext)
-          logError(s"Job $jobId aborted.")
-        }
+        while (iterator.hasNext) {
+          dataWriter.write(iterator.next())
+        }
+        dataWriter.commit()
+      })(catchBlock = {
+        // If there is an error, abort the task
+        dataWriter.abort()
+        logError(s"Job $jobId aborted.")
       })
     } catch {
       case e: FetchFailedException =>
```
```diff
@@ -302,7 +260,7 @@ object FileFormatWriter extends Logging {
   private def processStats(
       statsTrackers: Seq[WriteJobStatsTracker],
       statsPerTask: Seq[Seq[WriteTaskStats]])
-  : Unit = {
+    : Unit = {

     val numStatsTrackers = statsTrackers.length
     assert(statsPerTask.forall(_.length == numStatsTrackers),
```
```diff
@@ -321,281 +279,4 @@ object FileFormatWriter extends Logging {
       case (statsTracker, stats) => statsTracker.processStats(stats)
     }
   }
-
-  /**
-   * A simple trait for writing out data in a single Spark task, without any concerns about how
-   * to commit or abort tasks. Exceptions thrown by the implementation of this trait will
-   * automatically trigger task aborts.
-   */
-  private trait ExecuteWriteTask {
-
-    /**
-     * Writes data out to files, and then returns the summary of relative information which
-     * includes the list of partition strings written out. The list of partitions is sent back
-     * to the driver and used to update the catalog. Other information will be sent back to the
-     * driver too and used to e.g. update the metrics in UI.
-     */
-    def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary
-    def releaseResources(): Unit
-  }
-
-  /** ExecuteWriteTask for empty partitions */
-  private class EmptyDirectoryWriteTask(description: WriteJobDescription)
-    extends ExecuteWriteTask {
-
-    val statsTrackers: Seq[WriteTaskStatsTracker] =
-      description.statsTrackers.map(_.newTaskInstance())
-
-    override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
-      ExecutedWriteSummary(
-        updatedPartitions = Set.empty,
-        stats = statsTrackers.map(_.getFinalStats()))
-    }
-
-    override def releaseResources(): Unit = {}
-  }
-
-  /** Writes data to a single directory (used for non-dynamic-partition writes). */
-  private class SingleDirectoryWriteTask(
-      description: WriteJobDescription,
-      taskAttemptContext: TaskAttemptContext,
-      committer: FileCommitProtocol) extends ExecuteWriteTask {
-
-    private[this] var currentWriter: OutputWriter = _
-
-    val statsTrackers: Seq[WriteTaskStatsTracker] =
-      description.statsTrackers.map(_.newTaskInstance())
-
-    private def newOutputWriter(fileCounter: Int): Unit = {
-      val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
-      val currentPath = committer.newTaskTempFile(
-        taskAttemptContext,
-        None,
-        f"-c$fileCounter%03d" + ext)
-
-      currentWriter = description.outputWriterFactory.newInstance(
-        path = currentPath,
-        dataSchema = description.dataColumns.toStructType,
-        context = taskAttemptContext)
-
-      statsTrackers.map(_.newFile(currentPath))
-    }
-
-    override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
-      var fileCounter = 0
-      var recordsInFile: Long = 0L
-      newOutputWriter(fileCounter)
-
-      while (iter.hasNext) {
-        if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
-          fileCounter += 1
-          assert(fileCounter < MAX_FILE_COUNTER,
-            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
-
-          recordsInFile = 0
-          releaseResources()
-          newOutputWriter(fileCounter)
-        }
-
-        val internalRow = iter.next()
-        currentWriter.write(internalRow)
-        statsTrackers.foreach(_.newRow(internalRow))
-        recordsInFile += 1
-      }
-      releaseResources()
-      ExecutedWriteSummary(
-        updatedPartitions = Set.empty,
-        stats = statsTrackers.map(_.getFinalStats()))
-    }
-
-    override def releaseResources(): Unit = {
-      if (currentWriter != null) {
-        try {
-          currentWriter.close()
-        } finally {
-          currentWriter = null
-        }
-      }
-    }
-  }
-
-  /**
-   * Writes data to using dynamic partition writes, meaning this single function can write to
-   * multiple directories (partitions) or files (bucketing).
-   */
-  private class DynamicPartitionWriteTask(
-      desc: WriteJobDescription,
-      taskAttemptContext: TaskAttemptContext,
-      committer: FileCommitProtocol) extends ExecuteWriteTask {
-
-    /** Flag saying whether or not the data to be written out is partitioned. */
-    val isPartitioned = desc.partitionColumns.nonEmpty
-
-    /** Flag saying whether or not the data to be written out is bucketed. */
-    val isBucketed = desc.bucketIdExpression.isDefined
-
-    assert(isPartitioned || isBucketed,
-      s"""DynamicPartitionWriteTask should be used for writing out data that's either
-         |partitioned or bucketed. In this case neither is true.
-         |WriteJobDescription: ${desc}
-       """.stripMargin)
-
-    // currentWriter is initialized whenever we see a new key (partitionValues + BucketId)
-    private var currentWriter: OutputWriter = _
-
-    /** Trackers for computing various statistics on the data as it's being written out. */
-    private val statsTrackers: Seq[WriteTaskStatsTracker] =
-      desc.statsTrackers.map(_.newTaskInstance())
-
-    /** Extracts the partition values out of an input row. */
-    private lazy val getPartitionValues: InternalRow => UnsafeRow = {
-      val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns)
-      row => proj(row)
-    }
-
-    /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */
-    private lazy val partitionPathExpression: Expression = Concat(
-      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
-        val partitionName = ScalaUDF(
-          ExternalCatalogUtils.getPartitionPathString _,
-          StringType,
-          Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId))))
-        if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
-      })
-
-    /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
-     * the partition string. */
-    private lazy val getPartitionPath: InternalRow => String = {
-      val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns)
-      row => proj(row).getString(0)
-    }
-
-    /** Given an input row, returns the corresponding `bucketId` */
-    private lazy val getBucketId: InternalRow => Int = {
-      val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns)
-      row => proj(row).getInt(0)
-    }
-
-    /** Returns the data columns to be written given an input row */
-    private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
-
-    /**
-     * Opens a new OutputWriter given a partition key and/or a bucket id.
-     * If bucket id is specified, we will append it to the end of the file name, but before the
-     * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
-     *
-     * @param partitionValues the partition which all tuples being written by this `OutputWriter`
-     *                        belong to
-     * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
-     * @param fileCounter the number of files that have been written in the past for this specific
-     *                    partition. This is used to limit the max number of records written for a
-     *                    single file. The value should start from 0.
-     * @param updatedPartitions the set of updated partition paths, we should add the new partition
-     *                          path of this writer to it.
-     */
-    private def newOutputWriter(
-        partitionValues: Option[InternalRow],
-        bucketId: Option[Int],
-        fileCounter: Int,
-        updatedPartitions: mutable.Set[String]): Unit = {
-
-      val partDir = partitionValues.map(getPartitionPath(_))
-      partDir.foreach(updatedPartitions.add)
-
-      val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
-
-      // This must be in a form that matches our bucketing format. See BucketingUtils.
-      val ext = f"$bucketIdStr.c$fileCounter%03d" +
-        desc.outputWriterFactory.getFileExtension(taskAttemptContext)
-
-      val customPath = partDir.flatMap { dir =>
-        desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
-      }
-      val currentPath = if (customPath.isDefined) {
-        committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
-      } else {
-        committer.newTaskTempFile(taskAttemptContext, partDir, ext)
-      }
-
-      currentWriter = desc.outputWriterFactory.newInstance(
-        path = currentPath,
-        dataSchema = desc.dataColumns.toStructType,
-        context = taskAttemptContext)
-
-      statsTrackers.foreach(_.newFile(currentPath))
-    }
-
-    override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
-      // If anything below fails, we should abort the task.
-      var recordsInFile: Long = 0L
-      var fileCounter = 0
-      val updatedPartitions = mutable.Set[String]()
-      var currentPartionValues: Option[UnsafeRow] = None
-      var currentBucketId: Option[Int] = None
-
-      for (row <- iter) {
-        val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
-        val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None
-
-        if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
-          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
-          if (isPartitioned && currentPartionValues != nextPartitionValues) {
-            currentPartionValues = Some(nextPartitionValues.get.copy())
-            statsTrackers.foreach(_.newPartition(currentPartionValues.get))
-          }
-          if (isBucketed) {
-            currentBucketId = nextBucketId
-            statsTrackers.foreach(_.newBucket(currentBucketId.get))
-          }
-
-          recordsInFile = 0
-          fileCounter = 0
-
-          releaseResources()
-          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
-        } else if (desc.maxRecordsPerFile > 0 &&
-            recordsInFile >= desc.maxRecordsPerFile) {
-          // Exceeded the threshold in terms of the number of records per file.
-          // Create a new file by increasing the file counter.
-          recordsInFile = 0
-          fileCounter += 1
-          assert(fileCounter < MAX_FILE_COUNTER,
-            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
-
-          releaseResources()
-          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
-        }
-        val outputRow = getOutputRow(row)
-        currentWriter.write(outputRow)
-        statsTrackers.foreach(_.newRow(outputRow))
-        recordsInFile += 1
-      }
-      releaseResources()
-
-      ExecutedWriteSummary(
-        updatedPartitions = updatedPartitions.toSet,
-        stats = statsTrackers.map(_.getFinalStats()))
-    }
-
-    override def releaseResources(): Unit = {
-      if (currentWriter != null) {
-        try {
-          currentWriter.close()
-        } finally {
-          currentWriter = null
-        }
-      }
-    }
-  }
 }
-
-/**
- * Wrapper class for the metrics of writing data out.
- *
- * @param updatedPartitions the partitions updated during writing data out. Only valid
- *                          for dynamic partition.
- * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had.
- */
-case class ExecutedWriteSummary(
-    updatedPartitions: Set[String],
-    stats: Seq[WriteTaskStats])
```
```diff
@@ -116,7 +116,9 @@ object DataWritingSparkTask extends Logging {

     // write the data and commit this writer.
     Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-      iter.foreach(dataWriter.write)
+      while (iter.hasNext) {
+        dataWriter.write(iter.next())
+      }

       val msg = if (useCommitCoordinator) {
         val coordinator = SparkEnv.get.outputCommitCoordinator
```