[SPARK-26673][SQL] File source V2 writes: create framework and migrate ORC
## What changes were proposed in this pull request?

Create a framework for the write path of File Source V2, and migrate the ORC write path to V2.

Supported:
* Write to file as a DataFrame

Not supported:
* Partitioning, which is still under development in the data source V2 project.
* Bucketing, which is still under development in the data source V2 project.
* Catalog.

## How was this patch tested?

Unit tests.

Closes #23601 from gengliangwang/orc_write.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Parent: b3b62ba303, commit: df4c53e44b
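For readers skimming the diff below, here is a minimal usage sketch (not part of the commit; it assumes a running SparkSession named `spark`, and the output paths are hypothetical): a plain DataFrame write to ORC now takes the new file source V2 write path, and the new `spark.sql.sources.write.useV1SourceList` conf forces a fallback to the V1 `FileFormat` path.

```scala
// Minimal sketch: with this patch, a plain ORC write goes through the V2 write framework.
val df = spark.range(10).toDF("id")
df.write.format("orc").mode("overwrite").save("/tmp/orc_v2_demo")      // hypothetical path

// Listing "orc" in the new (internal) conf makes the same write fall back to the V1 path.
spark.conf.set("spark.sql.sources.write.useV1SourceList", "orc")
df.write.format("orc").mode("overwrite").save("/tmp/orc_v1_demo")      // hypothetical path
```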
@@ -1440,6 +1440,14 @@ object SQLConf {
    .stringConf
    .createWithDefault("")

  val USE_V1_SOURCE_WRITER_LIST = buildConf("spark.sql.sources.write.useV1SourceList")
    .internal()
    .doc("A comma-separated list of data source short names or fully qualified data source" +
      " register class names for which data source V2 write paths are disabled. Writes from these" +
      " sources will fall back to the V1 sources.")
    .stringConf
    .createWithDefault("")

  val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
    .doc("A comma-separated list of fully qualified data source register class names for which" +
      " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.")

@@ -2026,6 +2034,8 @@ class SQLConf extends Serializable with Logging {

  def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST)

  def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST)

  def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS)

  def disabledV2StreamingMicroBatchReaders: String =
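A hedged note on how the new list conf is consumed (the consuming code is in the `DataFrameWriter.save` hunk below): the value is split on commas and matched case-insensitively against either a source's short name or its fully qualified register class name.

```scala
// Sketch only: entries may be short names or fully qualified register class names;
// the DataFrameWriter change lowercases them with Locale.ROOT before matching.
spark.conf.set("spark.sql.sources.write.useV1SourceList", "orc,csv")
```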
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable,
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode
@@ -243,8 +243,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    assertNotBucketed("save")

    val session = df.sparkSession
    val cls = DataSource.lookupDataSource(source, session.sessionState.conf)
    if (classOf[TableProvider].isAssignableFrom(cls)) {
    val useV1Sources =
      session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
    val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf)
    val cls = lookupCls.newInstance() match {
      case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) ||
        useV1Sources.contains(lookupCls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
        f.fallBackFileFormat
      case _ => lookupCls
    }
    // In Data Source V2 project, partitioning is still under development.
    // Here we fallback to V1 if partitioning columns are specified.
    // TODO(SPARK-26778): use V2 implementations when partitioning feature is supported.
    if (classOf[TableProvider].isAssignableFrom(cls) && partitioningColumns.isEmpty) {
      val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
      val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
        provider, session.sessionState.conf)
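As the TODO(SPARK-26778) comment above notes, partitioned writes are not yet handled by the V2 framework: the `partitioningColumns.isEmpty` guard keeps them on the V1 path. A small illustrative sketch (output path is hypothetical):

```scala
// A write that specifies partition columns still goes through the V1 FileFormat path.
spark.range(10)
  .selectExpr("id", "id % 2 AS p")
  .write
  .partitionBy("p")
  .format("orc")
  .mode("overwrite")
  .save("/tmp/orc_partitioned_demo")  // hypothetical path
```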
@@ -763,7 +763,7 @@ object DataSource extends Logging {
   * supplied schema is not empty.
   * @param schema
   */
  private def validateSchema(schema: StructType): Unit = {
  def validateSchema(schema: StructType): Unit = {
    def hasEmptySchema(schema: StructType): Boolean = {
      schema.size == 0 || schema.find {
        case StructField(_, b: StructType, _, _) => hasEmptySchema(b)
@@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
 * Replace the ORC V2 data source of table in [[InsertIntoTable]] to V1 [[FileFormat]].
 * E.g, with temporary view `t` using [[FileDataSourceV2]], inserting into view `t` fails
 * since there is no corresponding physical plan.
 * SPARK-23817: This is a temporary hack for making current data source V2 work. It should be
 * removed when write path of file data source v2 is finished.
 * This is a temporary hack for making current data source V2 work. It should be
 * removed when Catalog support of file data source v2 is finished.
 */
class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
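The scenario the scaladoc above describes, sketched as user code (paths hypothetical): without this rule, the INSERT below would hit an ORC V2 relation that has no physical write plan yet, so the rule rewrites the relation to its V1 `FileFormat` form first.

```scala
// Sketch of the INSERT-into-a-file-backed-temp-view case handled by the rule above.
spark.range(5).write.format("orc").save("/tmp/orc_insert_demo")                  // hypothetical path
spark.read.format("orc").load("/tmp/orc_insert_demo").createOrReplaceTempView("t")
spark.sql("INSERT INTO t SELECT id FROM range(5)")
```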
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.SerializableConfiguration

@@ -37,7 +38,7 @@ import org.apache.spark.util.SerializableConfiguration
abstract class FileFormatDataWriter(
    description: WriteJobDescription,
    taskAttemptContext: TaskAttemptContext,
    committer: FileCommitProtocol) {
    committer: FileCommitProtocol) extends DataWriter[InternalRow] {
  /**
   * Max number of files a single task writes out due to file size. In most cases the number of
   * files written should be very small. This is just a safe guard to protect some really bad

@@ -70,7 +71,7 @@ abstract class FileFormatDataWriter(
   * to the driver and used to update the catalog. Other information will be sent back to the
   * driver too and used to e.g. update the metrics in UI.
   */
  def commit(): WriteTaskResult = {
  override def commit(): WriteTaskResult = {
    releaseResources()
    val summary = ExecutedWriteSummary(
      updatedPartitions = updatedPartitions.toSet,

@@ -301,6 +302,7 @@ class WriteJobDescription(

/** The result of a successful write task. */
case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
  extends WriterCommitMessage

/**
 * Wrapper class for the metrics of writing data out.
@@ -259,7 +259,7 @@ object FileFormatWriter extends Logging {
   * For every registered [[WriteJobStatsTracker]], call `processStats()` on it, passing it
   * the corresponding [[WriteTaskStats]] from all executors.
   */
  private def processStats(
  private[datasources] def processStats(
      statsTrackers: Seq[WriteJobStatsTracker],
      statsPerTask: Seq[Seq[WriteTaskStats]])
    : Unit = {
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._

private[orc] class OrcOutputWriter(
private[sql] class OrcOutputWriter(
    path: String,
    dataSchema: StructType,
    context: TaskAttemptContext)
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.datasources.v2

import org.apache.hadoop.mapreduce.Job

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult}
import org.apache.spark.sql.execution.datasources.FileFormatWriter.processStats
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.util.SerializableConfiguration

class FileBatchWrite(
    job: Job,
    description: WriteJobDescription,
    committer: FileCommitProtocol)
  extends BatchWrite with Logging {
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    val results = messages.map(_.asInstanceOf[WriteTaskResult])
    committer.commitJob(job, results.map(_.commitMsg))
    logInfo(s"Write Job ${description.uuid} committed.")

    processStats(description.statsTrackers, results.map(_.summary.stats))
    logInfo(s"Finished processing stats for write job ${description.uuid}.")
  }

  override def useCommitCoordinator(): Boolean = false

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    committer.abortJob(job)
  }

  override def createBatchWriterFactory(): DataWriterFactory = {
    val conf = new SerializableConfiguration(job.getConfiguration)
    FileWriterFactory(description, committer, conf)
  }
}
@@ -20,13 +20,14 @@ import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.{SupportsBatchRead, Table}
import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table}
import org.apache.spark.sql.types.StructType

abstract class FileTable(
    sparkSession: SparkSession,
    fileIndex: PartitioningAwareFileIndex,
    userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead {
    userSpecifiedSchema: Option[StructType])
  extends Table with SupportsBatchRead with SupportsBatchWrite {
  def getFileIndex: PartitioningAwareFileIndex = this.fileIndex

  lazy val dataSchema: StructType = userSpecifiedSchema.orElse {
@@ -0,0 +1,158 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.datasources.v2

import java.util.UUID

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

abstract class FileWriteBuilder(options: DataSourceOptions)
  extends WriteBuilder with SupportsSaveMode {
  private var schema: StructType = _
  private var queryId: String = _
  private var mode: SaveMode = _

  override def withInputDataSchema(schema: StructType): WriteBuilder = {
    this.schema = schema
    this
  }

  override def withQueryId(queryId: String): WriteBuilder = {
    this.queryId = queryId
    this
  }

  override def mode(mode: SaveMode): WriteBuilder = {
    this.mode = mode
    this
  }

  override def buildForBatch(): BatchWrite = {
    validateInputs()
    val pathName = options.paths().head
    val path = new Path(pathName)
    val sparkSession = SparkSession.active
    val optionsAsScala = options.asMap().asScala.toMap
    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala)
    val job = getJobInstance(hadoopConf, path)
    val committer = FileCommitProtocol.instantiate(
      sparkSession.sessionState.conf.fileCommitProtocolClass,
      jobId = java.util.UUID.randomUUID().toString,
      outputPath = pathName)
    lazy val description =
      createWriteJobDescription(sparkSession, hadoopConf, job, pathName, optionsAsScala)

    val fs = path.getFileSystem(hadoopConf)
    mode match {
      case SaveMode.ErrorIfExists if fs.exists(path) =>
        val qualifiedOutputPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory)
        throw new AnalysisException(s"path $qualifiedOutputPath already exists.")

      case SaveMode.Ignore if fs.exists(path) =>
        null

      case SaveMode.Overwrite =>
        committer.deleteWithJob(fs, path, true)
        committer.setupJob(job)
        new FileBatchWrite(job, description, committer)

      case _ =>
        committer.setupJob(job)
        new FileBatchWrite(job, description, committer)
    }
  }

  /**
   * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
   * be put here. For example, user defined output committer can be configured here
   * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
   */
  def prepareWrite(
      sqlConf: SQLConf,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory

  private def validateInputs(): Unit = {
    assert(schema != null, "Missing input data schema")
    assert(queryId != null, "Missing query ID")
    assert(mode != null, "Missing save mode")
    assert(options.paths().length == 1)
    DataSource.validateSchema(schema)
  }

  private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
    val job = Job.getInstance(hadoopConf)
    job.setOutputKeyClass(classOf[Void])
    job.setOutputValueClass(classOf[InternalRow])
    FileOutputFormat.setOutputPath(job, path)
    job
  }

  private def createWriteJobDescription(
      sparkSession: SparkSession,
      hadoopConf: Configuration,
      job: Job,
      pathName: String,
      options: Map[String, String]): WriteJobDescription = {
    val caseInsensitiveOptions = CaseInsensitiveMap(options)
    // Note: prepareWrite has side effect. It sets "job".
    val outputWriterFactory =
      prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, schema)
    val allColumns = schema.toAttributes
    val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
    val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
    // TODO: after partitioning is supported in V2:
    //       1. filter out partition columns in `dataColumns`.
    //       2. Don't use Seq.empty for `partitionColumns`.
    new WriteJobDescription(
      uuid = UUID.randomUUID().toString,
      serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
      outputWriterFactory = outputWriterFactory,
      allColumns = allColumns,
      dataColumns = allColumns,
      partitionColumns = Seq.empty,
      bucketIdExpression = None,
      path = pathName,
      customPartitionLocations = Map.empty,
      maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
      timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
      statsTrackers = Seq(statsTracker)
    )
  }
}
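A hedged summary of how the `SaveMode` handling in `buildForBatch` above surfaces to users of the V2 write path (output path is hypothetical):

```scala
val out = "/tmp/orc_savemode_demo"                              // hypothetical path
spark.range(3).write.format("orc").mode("overwrite").save(out) // deleteWithJob + setupJob
spark.range(3).write.format("orc").mode("ignore").save(out)    // existing path: effectively a no-op
// spark.range(3).write.format("orc").mode("error").save(out)  // existing path: AnalysisException "path ... already exists"
spark.range(3).write.format("orc").mode("append").save(out)    // default case: setupJob only
```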
@@ -0,0 +1,58 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.datasources.v2

import java.util.Date

import org.apache.hadoop.mapreduce.{TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataWriter, SingleDirectoryDataWriter, WriteJobDescription}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
import org.apache.spark.util.SerializableConfiguration

case class FileWriterFactory (
    description: WriteJobDescription,
    committer: FileCommitProtocol,
    conf: SerializableConfiguration) extends DataWriterFactory {
  override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = {
    val taskAttemptContext = createTaskAttemptContext(partitionId)
    committer.setupTask(taskAttemptContext)
    if (description.partitionColumns.isEmpty) {
      new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
    } else {
      new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
    }
  }

  private def createTaskAttemptContext(partitionId: Int): TaskAttemptContextImpl = {
    val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0)
    val taskId = new TaskID(jobId, TaskType.MAP, partitionId)
    val taskAttemptId = new TaskAttemptID(taskId, 0)
    // Set up the configuration object
    val hadoopConf = conf.value
    hadoopConf.set("mapreduce.job.id", jobId.toString)
    hadoopConf.set("mapreduce.task.id", taskId.toString)
    hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString)
    hadoopConf.setBoolean("mapreduce.task.ismap", true)
    hadoopConf.setInt("mapreduce.task.partition", 0)

    new TaskAttemptContextImpl(hadoopConf, taskAttemptId)
  }
}
@@ -56,7 +56,14 @@ case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan)
    val writerFactory = batchWrite.createBatchWriterFactory()
    val useCommitCoordinator = batchWrite.useCommitCoordinator
    val rdd = query.execute()
    val messages = new Array[WriterCommitMessage](rdd.partitions.length)
    // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
    // partition rdd to make sure we at least set up one write task to write the metadata.
    val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
      sparkContext.parallelize(Array.empty[InternalRow], 1)
    } else {
      rdd
    }
    val messages = new Array[WriterCommitMessage](rddWithNonEmptyPartitions.partitions.length)
    val totalNumRowsAccumulator = new LongAccumulator()

    logInfo(s"Start processing data source write support: $batchWrite. " +

@@ -64,10 +71,10 @@ case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan)

    try {
      sparkContext.runJob(
        rdd,
        rddWithNonEmptyPartitions,
        (context: TaskContext, iter: Iterator[InternalRow]) =>
          DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator),
        rdd.partitions.indices,
        rddWithNonEmptyPartitions.partitions.indices,
        (index, result: DataWritingSparkTaskResult) => {
          val commitMessage = result.writerCommitMessage
          messages(index) = commitMessage
@@ -23,6 +23,7 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.types.StructType

case class OrcTable(

@@ -36,4 +37,7 @@ case class OrcTable(
  override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
    OrcUtils.readSchema(sparkSession, files)

  override def newWriteBuilder(options: DataSourceOptions): WriteBuilder =
    new OrcWriteBuilder(options)
}
@@ -0,0 +1,66 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.datasources.v2.orc

import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
import org.apache.orc.mapred.OrcStruct

import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcOutputWriter, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.types._

class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(options) {
  override def prepareWrite(
      sqlConf: SQLConf,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    val orcOptions = new OrcOptions(options, sqlConf)

    val conf = job.getConfiguration

    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema))

    conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)

    conf.asInstanceOf[JobConf]
      .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])

    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new OrcOutputWriter(path, dataSchema, context)
      }

      override def getFileExtension(context: TaskAttemptContext): String = {
        val compressionExtension: String = {
          val name = context.getConfiguration.get(COMPRESS.getAttribute)
          OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "")
        }

        compressionExtension + ".orc"
      }
    }
  }
}
@@ -329,46 +329,49 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
  test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") {
    withTempDir { dir =>
      val tempDir = new File(dir, "files").getCanonicalPath
      // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
      withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
        // write path
        Seq("csv", "json", "parquet", "orc").foreach { format =>
          var msg = intercept[AnalysisException] {
            sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
          }.getMessage
          assert(msg.contains("Cannot save interval data type into external storage."))

      // write path
      Seq("csv", "json", "parquet", "orc").foreach { format =>
        var msg = intercept[AnalysisException] {
          sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
        }.getMessage
        assert(msg.contains("Cannot save interval data type into external storage."))
        msg = intercept[AnalysisException] {
          spark.udf.register("testType", () => new IntervalData())
          sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
        }.getMessage
        assert(msg.toLowerCase(Locale.ROOT)
          .contains(s"$format data source does not support calendarinterval data type."))
      }

          msg = intercept[AnalysisException] {
            spark.udf.register("testType", () => new IntervalData())
            sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
          }.getMessage
          assert(msg.toLowerCase(Locale.ROOT)
            .contains(s"$format data source does not support calendarinterval data type."))
        }
        // read path
        Seq("parquet", "csv").foreach { format =>
          var msg = intercept[AnalysisException] {
            val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
            spark.read.schema(schema).format(format).load(tempDir).collect()
          }.getMessage
          assert(msg.toLowerCase(Locale.ROOT)
            .contains(s"$format data source does not support calendarinterval data type."))

      // read path
      Seq("parquet", "csv").foreach { format =>
        var msg = intercept[AnalysisException] {
          val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
          spark.range(1).write.format(format).mode("overwrite").save(tempDir)
          spark.read.schema(schema).format(format).load(tempDir).collect()
        }.getMessage
        assert(msg.toLowerCase(Locale.ROOT)
          .contains(s"$format data source does not support calendarinterval data type."))

          msg = intercept[AnalysisException] {
            val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
            spark.read.schema(schema).format(format).load(tempDir).collect()
          }.getMessage
          assert(msg.toLowerCase(Locale.ROOT)
            .contains(s"$format data source does not support calendarinterval data type."))
        msg = intercept[AnalysisException] {
          val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
          spark.range(1).write.format(format).mode("overwrite").save(tempDir)
          spark.read.schema(schema).format(format).load(tempDir).collect()
        }.getMessage
        assert(msg.toLowerCase(Locale.ROOT)
          .contains(s"$format data source does not support calendarinterval data type."))
        }
      }
    }
  }

  test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
    // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
    withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc") {
    withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc",
      SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
      withTempDir { dir =>
        val tempDir = new File(dir, "files").getCanonicalPath
@@ -22,6 +22,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Pa
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.reader.ScanBuilder
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
@@ -46,9 +47,30 @@ class DummyReadOnlyFileTable extends Table with SupportsBatchRead {
  }
}

class FileDataSourceV2FallBackSuite extends QueryTest with ParquetTest with SharedSQLContext {
class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 {

  override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat]

  override def shortName(): String = "parquet"

  override def getTable(options: DataSourceOptions): Table = {
    new DummyWriteOnlyFileTable
  }
}

class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite {
  override def name(): String = "dummy"

  override def schema(): StructType = StructType(Nil)

  override def newWriteBuilder(options: DataSourceOptions): WriteBuilder =
    throw new AnalysisException("Dummy file writer")
}

class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext {

  private val dummyParquetReaderV2 = classOf[DummyReadOnlyFileDataSourceV2].getName
  private val dummyParquetWriterV2 = classOf[DummyWriteOnlyFileDataSourceV2].getName

  test("Fall back to v1 when writing to file with read only FileDataSourceV2") {
    val df = spark.range(10).toDF()
@@ -94,4 +116,47 @@ class FileDataSourceV2FallBackSuite extends QueryTest with ParquetTest with Shar
      }
    }
  }

  test("Fall back to v1 when reading file with write only FileDataSourceV2") {
    val df = spark.range(10).toDF()
    withTempPath { file =>
      val path = file.getCanonicalPath
      // Dummy File writer should fail as expected.
      val exception = intercept[AnalysisException] {
        df.write.format(dummyParquetWriterV2).save(path)
      }
      assert(exception.message.equals("Dummy file writer"))
      df.write.parquet(path)
      // Fallback reads to V1
      checkAnswer(spark.read.format(dummyParquetWriterV2).load(path), df)
    }
  }

  test("Fall back write path to v1 with configuration USE_V1_SOURCE_WRITER_LIST") {
    val df = spark.range(10).toDF()
    Seq(
      "foo,parquet,bar",
      "ParQuet,bar,foo",
      s"foobar,$dummyParquetWriterV2"
    ).foreach { fallbackWriters =>
      withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> fallbackWriters) {
        withTempPath { file =>
          val path = file.getCanonicalPath
          // Writes should fall back to v1 and succeed.
          df.write.format(dummyParquetWriterV2).save(path)
          checkAnswer(spark.read.parquet(path), df)
        }
      }
    }
    withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "foo,bar") {
      withTempPath { file =>
        val path = file.getCanonicalPath
        // Dummy File reader should fail as USE_V1_SOURCE_READER_LIST doesn't include it.
        val exception = intercept[AnalysisException] {
          df.write.format(dummyParquetWriterV2).save(path)
        }
        assert(exception.message.equals("Dummy file writer"))
      }
    }
  }
}