diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 3b064a5bc4..7e12bbb212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -239,48 +239,50 @@ private[sql] class DefaultWriterContainer( extends BaseWriterContainer(relation, job, isAppend) { def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - executorSideSetup(taskContext) - val configuration = taskAttemptContext.getConfiguration - configuration.set("spark.sql.sources.output.path", outputPath) - var writer = newOutputWriter(getWorkPath) - writer.initConverter(dataSchema) + if (iterator.hasNext) { + executorSideSetup(taskContext) + val configuration = taskAttemptContext.getConfiguration + configuration.set("spark.sql.sources.output.path", outputPath) + var writer = newOutputWriter(getWorkPath) + writer.initConverter(dataSchema) - // If anything below fails, we should abort the task. - try { - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iterator.hasNext) { - val internalRow = iterator.next() - writer.writeInternal(internalRow) - } - commitTask() - }(catchBlock = abortTask()) - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } - - def commitTask(): Unit = { + // If anything below fails, we should abort the task. try { - if (writer != null) { - writer.close() - writer = null - } - super.commitTask() + Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val internalRow = iterator.next() + writer.writeInternal(internalRow) + } + commitTask() + }(catchBlock = abortTask()) } catch { - case cause: Throwable => - // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and - // will cause `abortTask()` to be invoked. - throw new RuntimeException("Failed to commit task", cause) + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) } - } - def abortTask(): Unit = { - try { - if (writer != null) { - writer.close() + def commitTask(): Unit = { + try { + if (writer != null) { + writer.close() + writer = null + } + super.commitTask() + } catch { + case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and + // will cause `abortTask()` to be invoked. + throw new RuntimeException("Failed to commit task", cause) + } + } + + def abortTask(): Unit = { + try { + if (writer != null) { + writer.close() + } + } finally { + super.abortTask() } - } finally { - super.abortTask() } } } @@ -363,84 +365,87 @@ private[sql] class DynamicPartitionWriterContainer( } def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - executorSideSetup(taskContext) + if (iterator.hasNext) { + executorSideSetup(taskContext) - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val sortingExpressions: Seq[Expression] = + partitionColumns ++ bucketIdExpression ++ sortColumns + val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. - case _ => StructField("bucketId", IntegerType, nullable = false) - }) - - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) - - // Returns the partition path given a partition key. - val getPartitionString = - UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - - // Sorts the data before write, so that we only need one writer at the same time. - // TODO: inject a local sort operator in planning. - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. + case _ => StructField("bucketId", IntegerType, nullable = false) }) - } - val sortedIterator = sorter.sortedIterator() + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) - // If anything below fails, we should abort the task. - var currentWriter: OutputWriter = null - try { - Utils.tryWithSafeFinallyAndFailureCallbacks { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) + + // Sorts the data before write, so that we only need one writer at the same time. + // TODO: inject a local sort operator in planning. + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity + } else { + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) + } + + val sortedIterator = sorter.sortedIterator() + + // If anything below fails, we should abort the task. + var currentWriter: OutputWriter = null + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + + currentWriter = newOutputWriter(currentKey, getPartitionString) } - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") - - currentWriter = newOutputWriter(currentKey, getPartitionString) + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - commitTask() - }(catchBlock = { - if (currentWriter != null) { - currentWriter.close() - } - abortTask() - }) - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) + commitTask() + }(catchBlock = { + if (currentWriter != null) { + currentWriter.close() + } + abortTask() + }) + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 794fe264ea..706fdbc260 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -178,19 +178,21 @@ private[hive] class SparkHiveWriterContainer( // this function is executed on executor side def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() - executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + if (iterator.hasNext) { + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) - iterator.foreach { row => - var i = 0 - while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) - i += 1 + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + i += 1 + } + writer.write(serializer.serialize(outputData, standardOI)) } - writer.write(serializer.serialize(outputData, standardOI)) - } - close() + close() + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 82d3e49f92..883cdac110 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException -import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql( s""" - |CREATE TABLE table_with_partition(c1 string) - |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) - |location '${tmpDir.toURI.toString}' - """.stripMargin) + |CREATE TABLE table_with_partition(c1 string) + |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) sql( """ |INSERT OVERWRITE TABLE table_with_partition @@ -216,6 +216,35 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE hiveTableWithStructValue") } + test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + val testDataset = hiveContext.sparkContext.parallelize( + (1 to 2).map(i => TestData(i, i.toString))).toDF() + testDataset.createOrReplaceTempView("testDataset") + + val tmpDir = Utils.createTempDir() + sql( + s""" + |CREATE TABLE table1(key int,value string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + """ + |INSERT OVERWRITE TABLE table1 + |SELECT count(key), value FROM testDataset GROUP BY value + """.stripMargin) + + val overwrittenFiles = tmpDir.listFiles() + .filter(f => f.isFile && !f.getName.endsWith(".crc")) + .sortBy(_.getName) + val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0) + + assert(overwrittenFiles === overwrittenFilesWithoutEmpty) + + sql("DROP TABLE table1") + } + } + test("Reject partitioning that does not match table") { withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index f4d63334b6..78d2dc28d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.DataSourceScanExec -import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -879,6 +879,26 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } } + + test("SPARK-10216: Avoid empty files during overwriting with group by query") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + withTempPath { path => + val df = spark.range(0, 5) + val groupedDF = df.groupBy("id").count() + groupedDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .save(path.getCanonicalPath) + + val overwrittenFiles = path.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + .sortBy(_.getName) + val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0) + + assert(overwrittenFiles === overwrittenFilesWithoutEmpty) + } + } + } } // This class is used to test SPARK-8578. We should not use any custom output committer when