diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
index f6cc8116c6..be7973b9d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
@@ -22,7 +22,7 @@ import java.util.UUID
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
 
 import org.apache.spark.internal.Logging
@@ -89,9 +89,7 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
       try {
         val fs = path.getFileSystem(jobContext.getConfiguration)
         // this is to make sure the file can be seen from driver as well
-        if (fs.exists(path)) {
-          fs.delete(path, false)
-        }
+        deleteIfExists(fs, path)
       } catch {
         case e: IOException =>
           logWarning(s"Fail to remove temporary file $path, continue removing next.", e)
@@ -139,7 +137,14 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
     if (addedFiles.nonEmpty) {
       val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration)
       val statuses: Seq[SinkFileStatus] =
-        addedFiles.map(f => SinkFileStatus(fs.getFileStatus(new Path(f))))
+        addedFiles.flatMap { f =>
+          val path = new Path(f)
+          if (fs.exists(path)) {
+            Some(SinkFileStatus(fs.getFileStatus(path)))
+          } else {
+            None
+          }
+        }
       new TaskCommitMessage(statuses)
     } else {
       new TaskCommitMessage(Seq.empty[SinkFileStatus])
@@ -150,7 +155,13 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
     // best effort cleanup of incomplete files
     if (addedFiles.nonEmpty) {
       val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration)
-      addedFiles.foreach { file => fs.delete(new Path(file), false) }
+      addedFiles.foreach { file => deleteIfExists(fs, new Path(file)) }
+    }
+  }
+
+  private def deleteIfExists(fs: FileSystem, path: Path, recursive: Boolean = false): Unit = {
+    if (fs.exists(path)) {
+      fs.delete(path, recursive)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 0b885c8429..f04da8bfc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -600,11 +600,61 @@ class FileStreamSinkV1Suite extends FileStreamSinkSuite {
 }
 
 class FileStreamSinkV2Suite extends FileStreamSinkSuite {
+  import testImplicits._
+
   override protected def sparkConf: SparkConf =
     super
       .sparkConf
       .set(SQLConf.USE_V1_SOURCE_LIST, "")
 
+  test("SPARK-29999 Handle FileStreamSink metadata correctly for empty partition") {
+    Seq("parquet", "orc", "text", "json").foreach { format =>
+      val inputData = MemoryStream[String]
+      val df = inputData.toDF()
+
+      withTempDir { outputDir =>
+        withTempDir { checkpointDir =>
+          var query: StreamingQuery = null
+          try {
+            // repartition to more than the input to leave empty partitions
+            query =
+              df.repartition(10)
+                .writeStream
+                .option("checkpointLocation", checkpointDir.getCanonicalPath)
+                .format(format)
+                .start(outputDir.getCanonicalPath)
+
+            inputData.addData("1", "2", "3")
+            inputData.addData("4", "5")
+
+            failAfter(streamingTimeout) {
+              query.processAllAvailable()
+            }
+          } finally {
+            if (query != null) {
+              query.stop()
+            }
+          }
+
+          val fs = new Path(outputDir.getCanonicalPath).getFileSystem(
+            spark.sessionState.newHadoopConf())
+          val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark,
+            outputDir.getCanonicalPath)
+
+          val allFiles = sinkLog.allFiles()
+          // only files from non-empty partition should be logged
+          assert(allFiles.length < 10)
+          assert(allFiles.forall(file => fs.exists(new Path(file.path))))
+
+          // the query should be able to read all rows correctly with metadata log
+          val outputDf = spark.read.format(format).load(outputDir.getCanonicalPath)
+            .selectExpr("CAST(value AS INT)").as[Int]
+          checkDatasetUnorderly(outputDf, 1, 2, 3, 4, 5)
+        }
+      }
+    }
+  }
+
   override def checkQueryExecution(df: DataFrame): Unit = {
     // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
     // been inferred