[SPARK-24330][SQL] Refactor ExecuteWriteTask and Use while in writing files

## What changes were proposed in this pull request?
1. Refactor `ExecuteWriteTask` in `FileFormatWriter` to factor out common logic and improve readability.
After the change, callers only need to call `commit()` or `abort()` at the end of the task, as sketched below.
There is also less code in `SingleDirectoryWriteTask` and `DynamicPartitionWriteTask`.
The definitions of the related classes are moved to a new file, and `ExecuteWriteTask` is renamed to `FileFormatDataWriter`.
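For reference, a minimal sketch of the per-task calling pattern after the refactoring, simplified from the `FileFormatWriter` changes in the diff below (the real code wraps this in `Utils.tryWithSafeFinallyAndFailureCallbacks`, and `description`, `iterator`, `sparkPartitionId`, `taskAttemptContext` and `committer` are assumed to be in scope as in the task-execution method):
```
// Pick the concrete FileFormatDataWriter for this task.
val dataWriter =
  if (sparkPartitionId != 0 && !iterator.hasNext) {
    // Non-first empty partitions write nothing; the first partition still writes
    // a file so formats like Parquet can save metadata even for an empty job.
    new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
  } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
    new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
  } else {
    new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
  }

try {
  // Drive the row iterator with a plain while loop; the writer handles
  // file rolling, stats tracking and partition/bucket switching internally.
  while (iterator.hasNext) {
    dataWriter.write(iterator.next())
  }
  dataWriter.commit()
} catch {
  case t: Throwable =>
    // abort() releases the current OutputWriter and aborts the Hadoop task.
    dataWriter.abort()
    throw t
}
```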

2. As per the code style guide (https://github.com/databricks/scala-style-guide#traversal-and-zipwithindex), we avoid using `for` loops in [FileFormatWriter](https://github.com/apache/spark/pull/21381/files#diff-3b69eb0963b68c65cfe8075f8a42e850L536) and `foreach` in [WriteToDataSourceV2Exec](https://github.com/apache/spark/pull/21381/files#diff-6fbe10db766049a395bae2e785e9d56eL119).
In such a critical code path, using `while` is better for performance; the change boils down to the pattern shown below.
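
For illustration only, the shape of the change in those hot loops (visible in the `WriteToDataSourceV2Exec` hunk at the end of the diff below; `iter` and `dataWriter` stand in for the real row iterator and writer):
```
// Before: Scala collection sugar on the row iterator.
iter.foreach(dataWriter.write)

// After: an explicit while loop over the same iterator.
while (iter.hasNext) {
  dataWriter.write(iter.next())
}
```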

## How was this patch tested?
Existing unit tests.
I also ran the microbenchmark from https://github.com/apache/spark/pull/21409:

| Workload | Before changes (Best/Avg Time, ms) | After changes (Best/Avg Time, ms) |
| --- | --- | --- |
|Output Single Int Column|    2018 / 2043   |    2096 / 2236 |
|Output Single Double Column| 1978 / 2043 | 2013 / 2018 |
|Output Int and String Column| 6332 / 6706 | 6162 / 6298 |
|Output Partitions| 4458 / 5094 |  3792 / 4008  |
|Output Buckets|           5695 / 6102 |    5120 / 5154 |

Also, a microbenchmark on my laptop for a general comparison of `while`/`foreach`/`for`:
```
class Writer {
  var sum = 0L
  def write(l: Long): Unit = sum += l
}

def testWhile(iterator: Iterator[Long]): Long = {
  val w = new Writer
  while (iterator.hasNext) {
    w.write(iterator.next())
  }
  w.sum
}

def testForeach(iterator: Iterator[Long]): Long = {
  val w = new Writer
  iterator.foreach(w.write)
  w.sum
}

def testFor(iterator: Iterator[Long]): Long = {
  val w = new Writer
  for (x <- iterator) {
    w.write(x)
  }
  w.sum
}

val data = 0L to 100000000L
val start = System.nanoTime
(0 to 10).foreach(_ => testWhile(data.iterator))
println("benchmark while: " + (System.nanoTime - start)/1000000)

val start2 = System.nanoTime
(0 to 10).foreach(_ => testForeach(data.iterator))
println("benchmark foreach: " + (System.nanoTime - start2)/1000000)

val start3 = System.nanoTime
(0 to 10).foreach(_ => testFor(data.iterator))
println("benchmark for: " + (System.nanoTime - start3)/1000000)
```
Benchmark results:
- `while`: 15401 ms
- `foreach`: 43034 ms
- `for`: 41279 ms

Author: Gengliang Wang <gengliang.wang@databricks.com>

Closes #21381 from gengliangwang/refactorExecuteWriteTask.
Authored by Gengliang Wang on 2018-06-01 10:01:15 +08:00, committed by Wenchen Fan
commit cbaa729132 (parent 2c9c8629b7)
4 changed files with 334 additions and 338 deletions


@@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration
/**
* Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]].
* Simple metrics collected during an instance of [[FileFormatDataWriter]].
* These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703).
*/
case class BasicWriteTaskStats(


@@ -0,0 +1,313 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import scala.collection.mutable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.SerializableConfiguration
/**
* Abstract class for writing out data in a single Spark task.
* Exceptions thrown by the implementation of this trait will automatically trigger task aborts.
*/
abstract class FileFormatDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) {
/**
* Max number of files a single task writes out due to file size. In most cases the number of
* files written should be very small. This is just a safe guard to protect some really bad
* settings, e.g. maxRecordsPerFile = 1.
*/
protected val MAX_FILE_COUNTER: Int = 1000 * 1000
protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]()
protected var currentWriter: OutputWriter = _
/** Trackers for computing various statistics on the data as it's being written out. */
protected val statsTrackers: Seq[WriteTaskStatsTracker] =
description.statsTrackers.map(_.newTaskInstance())
protected def releaseResources(): Unit = {
if (currentWriter != null) {
try {
currentWriter.close()
} finally {
currentWriter = null
}
}
}
/** Writes a record */
def write(record: InternalRow): Unit
/**
* Returns the summary of relative information which
* includes the list of partition strings written out. The list of partitions is sent back
* to the driver and used to update the catalog. Other information will be sent back to the
* driver too and used to e.g. update the metrics in UI.
*/
def commit(): WriteTaskResult = {
releaseResources()
val summary = ExecutedWriteSummary(
updatedPartitions = updatedPartitions.toSet,
stats = statsTrackers.map(_.getFinalStats()))
WriteTaskResult(committer.commitTask(taskAttemptContext), summary)
}
def abort(): Unit = {
try {
releaseResources()
} finally {
committer.abortTask(taskAttemptContext)
}
}
}
/** FileFormatWriteTask for empty partitions */
class EmptyDirectoryDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol
) extends FileFormatDataWriter(description, taskAttemptContext, committer) {
override def write(record: InternalRow): Unit = {}
}
/** Writes data to a single directory (used for non-dynamic-partition writes). */
class SingleDirectoryDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol)
extends FileFormatDataWriter(description, taskAttemptContext, committer) {
private var fileCounter: Int = _
private var recordsInFile: Long = _
// Initialize currentWriter and statsTrackers
newOutputWriter()
private def newOutputWriter(): Unit = {
recordsInFile = 0
releaseResources()
val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
val currentPath = committer.newTaskTempFile(
taskAttemptContext,
None,
f"-c$fileCounter%03d" + ext)
currentWriter = description.outputWriterFactory.newInstance(
path = currentPath,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
statsTrackers.foreach(_.newFile(currentPath))
}
override def write(record: InternalRow): Unit = {
if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
newOutputWriter()
}
currentWriter.write(record)
statsTrackers.foreach(_.newRow(record))
recordsInFile += 1
}
}
/**
* Writes data to using dynamic partition writes, meaning this single function can write to
* multiple directories (partitions) or files (bucketing).
*/
class DynamicPartitionDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol)
extends FileFormatDataWriter(description, taskAttemptContext, committer) {
/** Flag saying whether or not the data to be written out is partitioned. */
private val isPartitioned = description.partitionColumns.nonEmpty
/** Flag saying whether or not the data to be written out is bucketed. */
private val isBucketed = description.bucketIdExpression.isDefined
assert(isPartitioned || isBucketed,
s"""DynamicPartitionWriteTask should be used for writing out data that's either
|partitioned or bucketed. In this case neither is true.
|WriteJobDescription: $description
""".stripMargin)
private var fileCounter: Int = _
private var recordsInFile: Long = _
private var currentPartionValues: Option[UnsafeRow] = None
private var currentBucketId: Option[Int] = None
/** Extracts the partition values out of an input row. */
private lazy val getPartitionValues: InternalRow => UnsafeRow = {
val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns)
row => proj(row)
}
/** Expression that given partition columns builds a path string like: col1=val/col2=val/... */
private lazy val partitionPathExpression: Expression = Concat(
description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
val partitionName = ScalaUDF(
ExternalCatalogUtils.getPartitionPathString _,
StringType,
Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))))
if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
})
/** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
* the partition string. */
private lazy val getPartitionPath: InternalRow => String = {
val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns)
row => proj(row).getString(0)
}
/** Given an input row, returns the corresponding `bucketId` */
private lazy val getBucketId: InternalRow => Int = {
val proj =
UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
row => proj(row).getInt(0)
}
/** Returns the data columns to be written given an input row */
private val getOutputRow =
UnsafeProjection.create(description.dataColumns, description.allColumns)
/**
* Opens a new OutputWriter given a partition key and/or a bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
*
* @param partitionValues the partition which all tuples being written by this `OutputWriter`
* belong to
* @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
*/
private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = {
recordsInFile = 0
releaseResources()
val partDir = partitionValues.map(getPartitionPath(_))
partDir.foreach(updatedPartitions.add)
val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
// This must be in a form that matches our bucketing format. See BucketingUtils.
val ext = f"$bucketIdStr.c$fileCounter%03d" +
description.outputWriterFactory.getFileExtension(taskAttemptContext)
val customPath = partDir.flatMap { dir =>
description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
}
val currentPath = if (customPath.isDefined) {
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}
currentWriter = description.outputWriterFactory.newInstance(
path = currentPath,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
statsTrackers.foreach(_.newFile(currentPath))
}
override def write(record: InternalRow): Unit = {
val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None
if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
// See a new partition or bucket - write to a new partition dir (or a new bucket file).
if (isPartitioned && currentPartionValues != nextPartitionValues) {
currentPartionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartionValues.get))
}
if (isBucketed) {
currentBucketId = nextBucketId
statsTrackers.foreach(_.newBucket(currentBucketId.get))
}
fileCounter = 0
newOutputWriter(currentPartionValues, currentBucketId)
} else if (description.maxRecordsPerFile > 0 &&
recordsInFile >= description.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
newOutputWriter(currentPartionValues, currentBucketId)
}
val outputRow = getOutputRow(record)
currentWriter.write(outputRow)
statsTrackers.foreach(_.newRow(outputRow))
recordsInFile += 1
}
}
/** A shared job description for all the write tasks. */
class WriteJobDescription(
val uuid: String, // prevent collision between different (appending) write jobs
val serializableHadoopConf: SerializableConfiguration,
val outputWriterFactory: OutputWriterFactory,
val allColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
val bucketIdExpression: Option[Expression],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long,
val timeZoneId: String,
val statsTrackers: Seq[WriteJobStatsTracker])
extends Serializable {
assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
s"""
|All columns: ${allColumns.mkString(", ")}
|Partition columns: ${partitionColumns.mkString(", ")}
|Data columns: ${dataColumns.mkString(", ")}
""".stripMargin)
}
/** The result of a successful write task. */
case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
/**
* Wrapper class for the metrics of writing data out.
*
* @param updatedPartitions the partitions updated during writing data out. Only valid
* for dynamic partition.
* @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had.
*/
case class ExecutedWriteSummary(
updatedPartitions: Set[String],
stats: Seq[WriteTaskStats])


@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}
import scala.collection.mutable
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
@@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}
/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
/**
* Max number of files a single task writes out due to file size. In most cases the number of
* files written should be very small. This is just a safe guard to protect some really bad
* settings, e.g. maxRecordsPerFile = 1.
*/
private val MAX_FILE_COUNTER = 1000 * 1000
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
outputPath: String,
customPartitionLocations: Map[TablePartitionSpec, String],
outputColumns: Seq[Attribute])
/** A shared job description for all the write tasks. */
private class WriteJobDescription(
val uuid: String, // prevent collision between different (appending) write jobs
val serializableHadoopConf: SerializableConfiguration,
val outputWriterFactory: OutputWriterFactory,
val allColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
val bucketIdExpression: Option[Expression],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long,
val timeZoneId: String,
val statsTrackers: Seq[WriteJobStatsTracker])
extends Serializable {
assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
s"""
|All columns: ${allColumns.mkString(", ")}
|Partition columns: ${partitionColumns.mkString(", ")}
|Data columns: ${dataColumns.mkString(", ")}
""".stripMargin)
}
/** The result of a successful write task. */
private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
outputPath: String,
customPartitionLocations: Map[TablePartitionSpec, String],
outputColumns: Seq[Attribute])
/**
* Basic work flow of this command is:
@@ -262,30 +223,27 @@ object FileFormatWriter extends Logging {
committer.setupTask(taskAttemptContext)
val writeTask =
val dataWriter =
if (sparkPartitionId != 0 && !iterator.hasNext) {
// In case of empty job, leave first partition to save meta for file format like parquet.
new EmptyDirectoryWriteTask(description)
new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
} else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
} else {
new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
}
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Execute the task to write rows out and commit the task.
val summary = writeTask.execute(iterator)
writeTask.releaseResources()
WriteTaskResult(committer.commitTask(taskAttemptContext), summary)
})(catchBlock = {
// If there is an error, release resource and then abort the task
try {
writeTask.releaseResources()
} finally {
committer.abortTask(taskAttemptContext)
logError(s"Job $jobId aborted.")
while (iterator.hasNext) {
dataWriter.write(iterator.next())
}
dataWriter.commit()
})(catchBlock = {
// If there is an error, abort the task
dataWriter.abort()
logError(s"Job $jobId aborted.")
})
} catch {
case e: FetchFailedException =>
@@ -302,7 +260,7 @@ object FileFormatWriter extends Logging {
private def processStats(
statsTrackers: Seq[WriteJobStatsTracker],
statsPerTask: Seq[Seq[WriteTaskStats]])
: Unit = {
: Unit = {
val numStatsTrackers = statsTrackers.length
assert(statsPerTask.forall(_.length == numStatsTrackers),
@@ -321,281 +279,4 @@ object FileFormatWriter extends Logging {
case (statsTracker, stats) => statsTracker.processStats(stats)
}
}
/**
* A simple trait for writing out data in a single Spark task, without any concerns about how
* to commit or abort tasks. Exceptions thrown by the implementation of this trait will
* automatically trigger task aborts.
*/
private trait ExecuteWriteTask {
/**
* Writes data out to files, and then returns the summary of relative information which
* includes the list of partition strings written out. The list of partitions is sent back
* to the driver and used to update the catalog. Other information will be sent back to the
* driver too and used to e.g. update the metrics in UI.
*/
def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary
def releaseResources(): Unit
}
/** ExecuteWriteTask for empty partitions */
private class EmptyDirectoryWriteTask(description: WriteJobDescription)
extends ExecuteWriteTask {
val statsTrackers: Seq[WriteTaskStatsTracker] =
description.statsTrackers.map(_.newTaskInstance())
override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
ExecutedWriteSummary(
updatedPartitions = Set.empty,
stats = statsTrackers.map(_.getFinalStats()))
}
override def releaseResources(): Unit = {}
}
/** Writes data to a single directory (used for non-dynamic-partition writes). */
private class SingleDirectoryWriteTask(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends ExecuteWriteTask {
private[this] var currentWriter: OutputWriter = _
val statsTrackers: Seq[WriteTaskStatsTracker] =
description.statsTrackers.map(_.newTaskInstance())
private def newOutputWriter(fileCounter: Int): Unit = {
val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
val currentPath = committer.newTaskTempFile(
taskAttemptContext,
None,
f"-c$fileCounter%03d" + ext)
currentWriter = description.outputWriterFactory.newInstance(
path = currentPath,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
statsTrackers.map(_.newFile(currentPath))
}
override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
var fileCounter = 0
var recordsInFile: Long = 0L
newOutputWriter(fileCounter)
while (iter.hasNext) {
if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
recordsInFile = 0
releaseResources()
newOutputWriter(fileCounter)
}
val internalRow = iter.next()
currentWriter.write(internalRow)
statsTrackers.foreach(_.newRow(internalRow))
recordsInFile += 1
}
releaseResources()
ExecutedWriteSummary(
updatedPartitions = Set.empty,
stats = statsTrackers.map(_.getFinalStats()))
}
override def releaseResources(): Unit = {
if (currentWriter != null) {
try {
currentWriter.close()
} finally {
currentWriter = null
}
}
}
}
/**
* Writes data to using dynamic partition writes, meaning this single function can write to
* multiple directories (partitions) or files (bucketing).
*/
private class DynamicPartitionWriteTask(
desc: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends ExecuteWriteTask {
/** Flag saying whether or not the data to be written out is partitioned. */
val isPartitioned = desc.partitionColumns.nonEmpty
/** Flag saying whether or not the data to be written out is bucketed. */
val isBucketed = desc.bucketIdExpression.isDefined
assert(isPartitioned || isBucketed,
s"""DynamicPartitionWriteTask should be used for writing out data that's either
|partitioned or bucketed. In this case neither is true.
|WriteJobDescription: ${desc}
""".stripMargin)
// currentWriter is initialized whenever we see a new key (partitionValues + BucketId)
private var currentWriter: OutputWriter = _
/** Trackers for computing various statistics on the data as it's being written out. */
private val statsTrackers: Seq[WriteTaskStatsTracker] =
desc.statsTrackers.map(_.newTaskInstance())
/** Extracts the partition values out of an input row. */
private lazy val getPartitionValues: InternalRow => UnsafeRow = {
val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns)
row => proj(row)
}
/** Expression that given partition columns builds a path string like: col1=val/col2=val/... */
private lazy val partitionPathExpression: Expression = Concat(
desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
val partitionName = ScalaUDF(
ExternalCatalogUtils.getPartitionPathString _,
StringType,
Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId))))
if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
})
/** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
* the partition string. */
private lazy val getPartitionPath: InternalRow => String = {
val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns)
row => proj(row).getString(0)
}
/** Given an input row, returns the corresponding `bucketId` */
private lazy val getBucketId: InternalRow => Int = {
val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns)
row => proj(row).getInt(0)
}
/** Returns the data columns to be written given an input row */
private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
/**
* Opens a new OutputWriter given a partition key and/or a bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
*
* @param partitionValues the partition which all tuples being written by this `OutputWriter`
* belong to
* @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
* @param fileCounter the number of files that have been written in the past for this specific
* partition. This is used to limit the max number of records written for a
* single file. The value should start from 0.
* @param updatedPartitions the set of updated partition paths, we should add the new partition
* path of this writer to it.
*/
private def newOutputWriter(
partitionValues: Option[InternalRow],
bucketId: Option[Int],
fileCounter: Int,
updatedPartitions: mutable.Set[String]): Unit = {
val partDir = partitionValues.map(getPartitionPath(_))
partDir.foreach(updatedPartitions.add)
val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
// This must be in a form that matches our bucketing format. See BucketingUtils.
val ext = f"$bucketIdStr.c$fileCounter%03d" +
desc.outputWriterFactory.getFileExtension(taskAttemptContext)
val customPath = partDir.flatMap { dir =>
desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
}
val currentPath = if (customPath.isDefined) {
committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}
currentWriter = desc.outputWriterFactory.newInstance(
path = currentPath,
dataSchema = desc.dataColumns.toStructType,
context = taskAttemptContext)
statsTrackers.foreach(_.newFile(currentPath))
}
override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
// If anything below fails, we should abort the task.
var recordsInFile: Long = 0L
var fileCounter = 0
val updatedPartitions = mutable.Set[String]()
var currentPartionValues: Option[UnsafeRow] = None
var currentBucketId: Option[Int] = None
for (row <- iter) {
val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None
if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
// See a new partition or bucket - write to a new partition dir (or a new bucket file).
if (isPartitioned && currentPartionValues != nextPartitionValues) {
currentPartionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartionValues.get))
}
if (isBucketed) {
currentBucketId = nextBucketId
statsTrackers.foreach(_.newBucket(currentBucketId.get))
}
recordsInFile = 0
fileCounter = 0
releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
} else if (desc.maxRecordsPerFile > 0 &&
recordsInFile >= desc.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
recordsInFile = 0
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
}
val outputRow = getOutputRow(row)
currentWriter.write(outputRow)
statsTrackers.foreach(_.newRow(outputRow))
recordsInFile += 1
}
releaseResources()
ExecutedWriteSummary(
updatedPartitions = updatedPartitions.toSet,
stats = statsTrackers.map(_.getFinalStats()))
}
override def releaseResources(): Unit = {
if (currentWriter != null) {
try {
currentWriter.close()
} finally {
currentWriter = null
}
}
}
}
}
/**
* Wrapper class for the metrics of writing data out.
*
* @param updatedPartitions the partitions updated during writing data out. Only valid
* for dynamic partition.
* @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had.
*/
case class ExecutedWriteSummary(
updatedPartitions: Set[String],
stats: Seq[WriteTaskStats])


@@ -116,7 +116,9 @@ object DataWritingSparkTask extends Logging {
// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
iter.foreach(dataWriter.write)
while (iter.hasNext) {
dataWriter.write(iter.next())
}
val msg = if (useCommitCoordinator) {
val coordinator = SparkEnv.get.outputCommitCoordinator