[SPARK-10216][SQL] Avoid creating empty files during overwriting with group by query
## What changes were proposed in this pull request? Currently, an `INSERT INTO` with a `GROUP BY` query tries to create at least 200 files (the default value of `spark.sql.shuffle.partitions`), which results in many empty files. This PR avoids creating empty files when overwriting a Hive table, and when writing through the internal data sources, with a group-by query. It checks whether the given partition actually has data in it, and creates/writes a file only when it does. ## How was this patch tested? Unit tests in `InsertIntoHiveTableSuite` and `HadoopFsRelationTest`. Closes #8411 Author: hyukjinkwon <gurwls223@gmail.com> Author: Keuntae Park <sirpkt@apache.org> Closes #12855 from HyukjinKwon/pr/8411.
This commit is contained in:
parent
20a89478e1
commit
8d05a7a98b
|
@ -239,48 +239,50 @@ private[sql] class DefaultWriterContainer(
|
||||||
extends BaseWriterContainer(relation, job, isAppend) {
|
extends BaseWriterContainer(relation, job, isAppend) {
|
||||||
|
|
||||||
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
|
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
|
||||||
executorSideSetup(taskContext)
|
if (iterator.hasNext) {
|
||||||
val configuration = taskAttemptContext.getConfiguration
|
executorSideSetup(taskContext)
|
||||||
configuration.set("spark.sql.sources.output.path", outputPath)
|
val configuration = taskAttemptContext.getConfiguration
|
||||||
var writer = newOutputWriter(getWorkPath)
|
configuration.set("spark.sql.sources.output.path", outputPath)
|
||||||
writer.initConverter(dataSchema)
|
var writer = newOutputWriter(getWorkPath)
|
||||||
|
writer.initConverter(dataSchema)
|
||||||
|
|
||||||
// If anything below fails, we should abort the task.
|
// If anything below fails, we should abort the task.
|
||||||
try {
|
|
||||||
Utils.tryWithSafeFinallyAndFailureCallbacks {
|
|
||||||
while (iterator.hasNext) {
|
|
||||||
val internalRow = iterator.next()
|
|
||||||
writer.writeInternal(internalRow)
|
|
||||||
}
|
|
||||||
commitTask()
|
|
||||||
}(catchBlock = abortTask())
|
|
||||||
} catch {
|
|
||||||
case t: Throwable =>
|
|
||||||
throw new SparkException("Task failed while writing rows", t)
|
|
||||||
}
|
|
||||||
|
|
||||||
def commitTask(): Unit = {
|
|
||||||
try {
|
try {
|
||||||
if (writer != null) {
|
Utils.tryWithSafeFinallyAndFailureCallbacks {
|
||||||
writer.close()
|
while (iterator.hasNext) {
|
||||||
writer = null
|
val internalRow = iterator.next()
|
||||||
}
|
writer.writeInternal(internalRow)
|
||||||
super.commitTask()
|
}
|
||||||
|
commitTask()
|
||||||
|
}(catchBlock = abortTask())
|
||||||
} catch {
|
} catch {
|
||||||
case cause: Throwable =>
|
case t: Throwable =>
|
||||||
// This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
|
throw new SparkException("Task failed while writing rows", t)
|
||||||
// will cause `abortTask()` to be invoked.
|
|
||||||
throw new RuntimeException("Failed to commit task", cause)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
def abortTask(): Unit = {
|
def commitTask(): Unit = {
|
||||||
try {
|
try {
|
||||||
if (writer != null) {
|
if (writer != null) {
|
||||||
writer.close()
|
writer.close()
|
||||||
|
writer = null
|
||||||
|
}
|
||||||
|
super.commitTask()
|
||||||
|
} catch {
|
||||||
|
case cause: Throwable =>
|
||||||
|
// This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
|
||||||
|
// will cause `abortTask()` to be invoked.
|
||||||
|
throw new RuntimeException("Failed to commit task", cause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def abortTask(): Unit = {
|
||||||
|
try {
|
||||||
|
if (writer != null) {
|
||||||
|
writer.close()
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
super.abortTask()
|
||||||
}
|
}
|
||||||
} finally {
|
|
||||||
super.abortTask()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -363,84 +365,87 @@ private[sql] class DynamicPartitionWriterContainer(
|
||||||
}
|
}
|
||||||
|
|
||||||
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
|
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
|
||||||
executorSideSetup(taskContext)
|
if (iterator.hasNext) {
|
||||||
|
executorSideSetup(taskContext)
|
||||||
|
|
||||||
// We should first sort by partition columns, then bucket id, and finally sorting columns.
|
// We should first sort by partition columns, then bucket id, and finally sorting columns.
|
||||||
val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
|
val sortingExpressions: Seq[Expression] =
|
||||||
val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
|
partitionColumns ++ bucketIdExpression ++ sortColumns
|
||||||
|
val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
|
||||||
|
|
||||||
val sortingKeySchema = StructType(sortingExpressions.map {
|
val sortingKeySchema = StructType(sortingExpressions.map {
|
||||||
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
|
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
|
||||||
// The sorting expressions are all `Attribute` except bucket id.
|
// The sorting expressions are all `Attribute` except bucket id.
|
||||||
case _ => StructField("bucketId", IntegerType, nullable = false)
|
case _ => StructField("bucketId", IntegerType, nullable = false)
|
||||||
})
|
|
||||||
|
|
||||||
// Returns the data columns to be written given an input row
|
|
||||||
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
|
|
||||||
|
|
||||||
// Returns the partition path given a partition key.
|
|
||||||
val getPartitionString =
|
|
||||||
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
|
|
||||||
|
|
||||||
// Sorts the data before write, so that we only need one writer at the same time.
|
|
||||||
// TODO: inject a local sort operator in planning.
|
|
||||||
val sorter = new UnsafeKVExternalSorter(
|
|
||||||
sortingKeySchema,
|
|
||||||
StructType.fromAttributes(dataColumns),
|
|
||||||
SparkEnv.get.blockManager,
|
|
||||||
SparkEnv.get.serializerManager,
|
|
||||||
TaskContext.get().taskMemoryManager().pageSizeBytes)
|
|
||||||
|
|
||||||
while (iterator.hasNext) {
|
|
||||||
val currentRow = iterator.next()
|
|
||||||
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
|
|
||||||
}
|
|
||||||
logInfo(s"Sorting complete. Writing out partition files one at a time.")
|
|
||||||
|
|
||||||
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
|
|
||||||
identity
|
|
||||||
} else {
|
|
||||||
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
|
|
||||||
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
|
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
val sortedIterator = sorter.sortedIterator()
|
// Returns the data columns to be written given an input row
|
||||||
|
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
|
||||||
|
|
||||||
// If anything below fails, we should abort the task.
|
// Returns the partition path given a partition key.
|
||||||
var currentWriter: OutputWriter = null
|
val getPartitionString =
|
||||||
try {
|
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
|
||||||
Utils.tryWithSafeFinallyAndFailureCallbacks {
|
|
||||||
var currentKey: UnsafeRow = null
|
// Sorts the data before write, so that we only need one writer at the same time.
|
||||||
while (sortedIterator.next()) {
|
// TODO: inject a local sort operator in planning.
|
||||||
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
|
val sorter = new UnsafeKVExternalSorter(
|
||||||
if (currentKey != nextKey) {
|
sortingKeySchema,
|
||||||
if (currentWriter != null) {
|
StructType.fromAttributes(dataColumns),
|
||||||
currentWriter.close()
|
SparkEnv.get.blockManager,
|
||||||
currentWriter = null
|
SparkEnv.get.serializerManager,
|
||||||
|
TaskContext.get().taskMemoryManager().pageSizeBytes)
|
||||||
|
|
||||||
|
while (iterator.hasNext) {
|
||||||
|
val currentRow = iterator.next()
|
||||||
|
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
|
||||||
|
}
|
||||||
|
logInfo(s"Sorting complete. Writing out partition files one at a time.")
|
||||||
|
|
||||||
|
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
|
||||||
|
identity
|
||||||
|
} else {
|
||||||
|
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
|
||||||
|
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
val sortedIterator = sorter.sortedIterator()
|
||||||
|
|
||||||
|
// If anything below fails, we should abort the task.
|
||||||
|
var currentWriter: OutputWriter = null
|
||||||
|
try {
|
||||||
|
Utils.tryWithSafeFinallyAndFailureCallbacks {
|
||||||
|
var currentKey: UnsafeRow = null
|
||||||
|
while (sortedIterator.next()) {
|
||||||
|
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
|
||||||
|
if (currentKey != nextKey) {
|
||||||
|
if (currentWriter != null) {
|
||||||
|
currentWriter.close()
|
||||||
|
currentWriter = null
|
||||||
|
}
|
||||||
|
currentKey = nextKey.copy()
|
||||||
|
logDebug(s"Writing partition: $currentKey")
|
||||||
|
|
||||||
|
currentWriter = newOutputWriter(currentKey, getPartitionString)
|
||||||
}
|
}
|
||||||
currentKey = nextKey.copy()
|
currentWriter.writeInternal(sortedIterator.getValue)
|
||||||
logDebug(s"Writing partition: $currentKey")
|
}
|
||||||
|
if (currentWriter != null) {
|
||||||
currentWriter = newOutputWriter(currentKey, getPartitionString)
|
currentWriter.close()
|
||||||
|
currentWriter = null
|
||||||
}
|
}
|
||||||
currentWriter.writeInternal(sortedIterator.getValue)
|
|
||||||
}
|
|
||||||
if (currentWriter != null) {
|
|
||||||
currentWriter.close()
|
|
||||||
currentWriter = null
|
|
||||||
}
|
|
||||||
|
|
||||||
commitTask()
|
commitTask()
|
||||||
}(catchBlock = {
|
}(catchBlock = {
|
||||||
if (currentWriter != null) {
|
if (currentWriter != null) {
|
||||||
currentWriter.close()
|
currentWriter.close()
|
||||||
}
|
}
|
||||||
abortTask()
|
abortTask()
|
||||||
})
|
})
|
||||||
} catch {
|
} catch {
|
||||||
case t: Throwable =>
|
case t: Throwable =>
|
||||||
throw new SparkException("Task failed while writing rows", t)
|
throw new SparkException("Task failed while writing rows", t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -178,19 +178,21 @@ private[hive] class SparkHiveWriterContainer(
|
||||||
|
|
||||||
// this function is executed on executor side
|
// this function is executed on executor side
|
||||||
def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
|
def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
|
||||||
val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
|
if (iterator.hasNext) {
|
||||||
executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
|
val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
|
||||||
|
executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
|
||||||
|
|
||||||
iterator.foreach { row =>
|
iterator.foreach { row =>
|
||||||
var i = 0
|
var i = 0
|
||||||
while (i < fieldOIs.length) {
|
while (i < fieldOIs.length) {
|
||||||
outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
|
outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
|
||||||
i += 1
|
i += 1
|
||||||
|
}
|
||||||
|
writer.write(serializer.serialize(outputData, standardOI))
|
||||||
}
|
}
|
||||||
writer.write(serializer.serialize(outputData, standardOI))
|
|
||||||
}
|
|
||||||
|
|
||||||
close()
|
close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,13 +19,13 @@ package org.apache.spark.sql.hive
|
||||||
|
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
import org.apache.hadoop.hive.conf.HiveConf
|
|
||||||
import org.scalatest.BeforeAndAfter
|
import org.scalatest.BeforeAndAfter
|
||||||
|
|
||||||
import org.apache.spark.SparkException
|
import org.apache.spark.SparkException
|
||||||
import org.apache.spark.sql.{QueryTest, _}
|
import org.apache.spark.sql._
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
|
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
|
||||||
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
||||||
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.test.SQLTestUtils
|
import org.apache.spark.sql.test.SQLTestUtils
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.util.Utils
|
import org.apache.spark.util.Utils
|
||||||
|
@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
|
||||||
|
|
||||||
sql(
|
sql(
|
||||||
s"""
|
s"""
|
||||||
|CREATE TABLE table_with_partition(c1 string)
|
|CREATE TABLE table_with_partition(c1 string)
|
||||||
|PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
|
|PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
|
||||||
|location '${tmpDir.toURI.toString}'
|
|location '${tmpDir.toURI.toString}'
|
||||||
""".stripMargin)
|
""".stripMargin)
|
||||||
sql(
|
sql(
|
||||||
"""
|
"""
|
||||||
|INSERT OVERWRITE TABLE table_with_partition
|
|INSERT OVERWRITE TABLE table_with_partition
|
||||||
|
@ -216,6 +216,35 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
|
||||||
sql("DROP TABLE hiveTableWithStructValue")
|
sql("DROP TABLE hiveTableWithStructValue")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") {
|
||||||
|
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
|
||||||
|
val testDataset = hiveContext.sparkContext.parallelize(
|
||||||
|
(1 to 2).map(i => TestData(i, i.toString))).toDF()
|
||||||
|
testDataset.createOrReplaceTempView("testDataset")
|
||||||
|
|
||||||
|
val tmpDir = Utils.createTempDir()
|
||||||
|
sql(
|
||||||
|
s"""
|
||||||
|
|CREATE TABLE table1(key int,value string)
|
||||||
|
|location '${tmpDir.toURI.toString}'
|
||||||
|
""".stripMargin)
|
||||||
|
sql(
|
||||||
|
"""
|
||||||
|
|INSERT OVERWRITE TABLE table1
|
||||||
|
|SELECT count(key), value FROM testDataset GROUP BY value
|
||||||
|
""".stripMargin)
|
||||||
|
|
||||||
|
val overwrittenFiles = tmpDir.listFiles()
|
||||||
|
.filter(f => f.isFile && !f.getName.endsWith(".crc"))
|
||||||
|
.sortBy(_.getName)
|
||||||
|
val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
|
||||||
|
|
||||||
|
assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
|
||||||
|
|
||||||
|
sql("DROP TABLE table1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test("Reject partitioning that does not match table") {
|
test("Reject partitioning that does not match table") {
|
||||||
withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
|
withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
|
||||||
sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
|
sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
|
||||||
import org.apache.spark.deploy.SparkHadoopUtil
|
import org.apache.spark.deploy.SparkHadoopUtil
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.apache.spark.sql.execution.DataSourceScanExec
|
import org.apache.spark.sql.execution.DataSourceScanExec
|
||||||
import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
|
import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem}
|
||||||
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.test.SQLTestUtils
|
import org.apache.spark.sql.test.SQLTestUtils
|
||||||
|
@ -879,6 +879,26 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-10216: Avoid empty files during overwriting with group by query") {
|
||||||
|
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
|
||||||
|
withTempPath { path =>
|
||||||
|
val df = spark.range(0, 5)
|
||||||
|
val groupedDF = df.groupBy("id").count()
|
||||||
|
groupedDF.write
|
||||||
|
.format(dataSourceName)
|
||||||
|
.mode(SaveMode.Overwrite)
|
||||||
|
.save(path.getCanonicalPath)
|
||||||
|
|
||||||
|
val overwrittenFiles = path.listFiles()
|
||||||
|
.filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
|
||||||
|
.sortBy(_.getName)
|
||||||
|
val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
|
||||||
|
|
||||||
|
assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This class is used to test SPARK-8578. We should not use any custom output committer when
|
// This class is used to test SPARK-8578. We should not use any custom output committer when
|
||||||
|
|
Loading…
Reference in a new issue