# [SPARK-28554][SQL] Adds a v1 fallback writer implementation for v2 data source codepaths

## What changes were proposed in this pull request?

This PR adds a V1 fallback interface for writing to V2 tables through the V1 writer interfaces. The only SaveMode that will be called on the target table is Append. To support overwrite operations, the target table must implement the V2 interfaces such as `SupportsOverwrite` or `SupportsTruncate`; whether the operation is atomic is up to the target DataSource implementation. We do not support dynamicPartitionOverwrite, as we cannot call a `commit` method that actually cleans up the data in the partitions that were touched through this fallback.

## How was this patch tested?

Will add tests and an example implementation after comments + feedback. This is a proposal at this point.

Closes #25348 from brkyvz/v1WriteFallback.

Lead-authored-by: Burak Yavuz <brkyvz@gmail.com>
Co-authored-by: Burak Yavuz <burak@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit 4855bfe16b (parent c4257b18a1)
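To make the described behavior concrete, here is a minimal sketch of what exercises the fallback from the user side (the format name `com.example.V1FallbackProvider` is hypothetical; any table reporting `V1_BATCH_WRITE` behaves the same way):

```scala
val df = spark.range(3).toDF("id")  // assumes an active SparkSession `spark`

// Planned as AppendData and executed through the V1 fallback: the source's
// InsertableRelation.insert(df, overwrite = false) is called.
df.write.format("com.example.V1FallbackProvider").mode("append").save()

// Planned as OverwriteByExpression with an AlwaysTrue filter: the write
// builder's truncate() (or overwrite(filters)) runs first, then the
// append-style insert.
df.write.format("com.example.V1FallbackProvider").mode("overwrite").save()
```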
**TableCapability.java**

```diff
@@ -89,5 +89,14 @@ public enum TableCapability {
   /**
    * Signals that the table accepts input of any schema in a write operation.
    */
-  ACCEPT_ANY_SCHEMA
+  ACCEPT_ANY_SCHEMA,
+
+  /**
+   * Signals that the table supports append writes using the V1 InsertableRelation interface.
+   * <p>
+   * Tables that return this capability must create a V1WriteBuilder and may also support additional
+   * write modes, like {@link #TRUNCATE}, and {@link #OVERWRITE_BY_FILTER}, but cannot support
+   * {@link #OVERWRITE_DYNAMIC}.
+   */
+  V1_BATCH_WRITE
 }
```
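As a usage sketch, a table that opts into the fallback advertises the new capability next to its other write capabilities. This mirrors the test table added later in this PR; per the Javadoc above, only `OVERWRITE_DYNAMIC` must be left out:

```scala
import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.TableCapability

// Capabilities of a table that writes through the V1 fallback: V1_BATCH_WRITE
// plus filter-based overwrite and truncate, but never OVERWRITE_DYNAMIC.
val capabilities: util.Set[TableCapability] = Set(
  TableCapability.BATCH_WRITE,
  TableCapability.V1_BATCH_WRITE,
  TableCapability.OVERWRITE_BY_FILTER,
  TableCapability.TRUNCATE).asJava
```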
**DataSourceV2Strategy.scala**

```diff
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.execution.datasources.v2
 
+import java.util.UUID
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
@@ -29,8 +31,10 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.v2.TableCapability
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
+import org.apache.spark.sql.sources.v2.writer.V1WriteBuilder
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 object DataSourceV2Strategy extends Strategy with PredicateHelper {
@@ -169,10 +173,10 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicCreateTableAsSelectExec(
-            staging, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil
+            staging, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil
         case _ =>
           CreateTableAsSelectExec(
-            catalog, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil
+            catalog, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil
       }
 
     case ReplaceTable(catalog, ident, schema, parts, props, orCreate) =>
@@ -191,6 +195,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
           staging,
           ident,
           parts,
+          query,
           planLater(query),
           props,
           writeOptions,
@@ -200,6 +205,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
           catalog,
           ident,
           parts,
+          query,
           planLater(query),
           props,
           writeOptions,
@@ -207,7 +213,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
       }
 
     case AppendData(r: DataSourceV2Relation, query, _) =>
-      AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil
+      r.table.asWritable match {
+        case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
+          AppendDataExecV1(v1, r.options, query) :: Nil
+        case v2 =>
+          AppendDataExec(v2, r.options, planLater(query)) :: Nil
+      }
 
     case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) =>
       // fail if any filter cannot be converted. correctness depends on removing all matching data.
@@ -215,9 +226,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
         filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse(
           throw new AnalysisException(s"Cannot translate expression to source filter: $filter"))
       }.toArray
-      OverwriteByExpressionExec(
-        r.table.asWritable, filters, r.options, planLater(query)) :: Nil
+      r.table.asWritable match {
+        case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
+          OverwriteByExpressionExecV1(v1, filters, r.options, query) :: Nil
+        case v2 =>
+          OverwriteByExpressionExec(v2, filters, r.options, planLater(query)) :: Nil
+      }
 
     case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) =>
       OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil
```
**V1FallbackWriters.scala** (new file)

```diff
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import java.util.UUID
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, SaveMode}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.sources.{AlwaysTrue, CreatableRelationProvider, Filter, InsertableRelation}
+import org.apache.spark.sql.sources.v2.{SupportsWrite, Table}
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Physical plan node for append into a v2 table using V1 write interfaces.
+ *
+ * Rows in the output data set are appended.
+ */
+case class AppendDataExecV1(
+    table: SupportsWrite,
+    writeOptions: CaseInsensitiveStringMap,
+    plan: LogicalPlan) extends V1FallbackWriters {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    writeWithV1(newWriteBuilder().buildForV1Write())
+  }
+}
+
+/**
+ * Physical plan node for overwrite into a v2 table with V1 write interfaces. Note that when this
+ * interface is used, the atomicity of the operation depends solely on the target data source.
+ *
+ * Overwrites data in a table matched by a set of filters. Rows matching all of the filters will be
+ * deleted and rows in the output data set are appended.
+ *
+ * This plan is used to implement SaveMode.Overwrite. The behavior of SaveMode.Overwrite is to
+ * truncate the table -- delete all rows -- and append the output data set. This uses the filter
+ * AlwaysTrue to delete all rows.
+ */
+case class OverwriteByExpressionExecV1(
+    table: SupportsWrite,
+    deleteWhere: Array[Filter],
+    writeOptions: CaseInsensitiveStringMap,
+    plan: LogicalPlan) extends V1FallbackWriters {
+
+  private def isTruncate(filters: Array[Filter]): Boolean = {
+    filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    newWriteBuilder() match {
+      case builder: SupportsTruncate if isTruncate(deleteWhere) =>
+        writeWithV1(builder.truncate().asV1Builder.buildForV1Write())
+
+      case builder: SupportsOverwrite =>
+        writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write())
+
+      case _ =>
+        throw new SparkException(s"Table does not support overwrite by expression: $table")
+    }
+  }
+}
+
+/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
+sealed trait V1FallbackWriters extends SupportsV1Write {
+  override def output: Seq[Attribute] = Nil
+  override final def children: Seq[SparkPlan] = Nil
+
+  def table: SupportsWrite
+  def writeOptions: CaseInsensitiveStringMap
+
+  protected implicit class toV1WriteBuilder(builder: WriteBuilder) {
+    def asV1Builder: V1WriteBuilder = builder match {
+      case v1: V1WriteBuilder => v1
+      case other => throw new IllegalStateException(
+        s"The returned writer ${other} was no longer a V1WriteBuilder.")
+    }
+  }
+
+  protected def newWriteBuilder(): V1WriteBuilder = {
+    val writeBuilder = table.newWriteBuilder(writeOptions)
+      .withInputDataSchema(plan.schema)
+      .withQueryId(UUID.randomUUID().toString)
+    writeBuilder.asV1Builder
+  }
+}
+
+/**
+ * A trait that allows Tables that use V1 Writer interfaces to append data.
+ */
+trait SupportsV1Write extends SparkPlan {
+  // TODO: We should be able to work on SparkPlans at this point.
+  def plan: LogicalPlan
+
+  protected def writeWithV1(relation: InsertableRelation): RDD[InternalRow] = {
+    relation.insert(Dataset.ofRows(sqlContext.sparkSession, plan), overwrite = false)
+    sparkContext.emptyRDD
+  }
+}
```
**WriteToDataSourceV2Exec.scala**

```diff
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
 import org.apache.spark.sql.sources.v2.{StagedTable, SupportsWrite}
-import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.{LongAccumulator, Utils}
 
@@ -63,10 +63,11 @@ case class CreateTableAsSelectExec(
     catalog: TableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
+    plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
-    ifNotExists: Boolean) extends V2TableWriteExec {
+    ifNotExists: Boolean) extends V2TableWriteExec with SupportsV1Write {
 
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper
 
@@ -83,12 +84,14 @@ case class CreateTableAsSelectExec(
     catalog.createTable(
         ident, query.schema, partitioning.toArray, properties.asJava) match {
       case table: SupportsWrite =>
-        val batchWrite = table.newWriteBuilder(writeOptions)
+        val writeBuilder = table.newWriteBuilder(writeOptions)
           .withInputDataSchema(query.schema)
           .withQueryId(UUID.randomUUID().toString)
-          .buildForBatch()
 
-        doWrite(batchWrite)
+        writeBuilder match {
+          case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
+          case v2 => writeWithV2(v2.buildForBatch())
+        }
 
       case _ =>
         // table does not support writes
@@ -114,6 +117,7 @@ case class AtomicCreateTableAsSelectExec(
     catalog: StagingTableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
+    plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
@@ -147,10 +151,11 @@ case class ReplaceTableAsSelectExec(
     catalog: TableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
+    plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
-    orCreate: Boolean) extends AtomicTableWriteExec {
+    orCreate: Boolean) extends V2TableWriteExec with SupportsV1Write {
 
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper
 
@@ -173,12 +178,14 @@ case class ReplaceTableAsSelectExec(
     Utils.tryWithSafeFinallyAndFailureCallbacks({
       createdTable match {
         case table: SupportsWrite =>
-          val batchWrite = table.newWriteBuilder(writeOptions)
+          val writeBuilder = table.newWriteBuilder(writeOptions)
            .withInputDataSchema(query.schema)
            .withQueryId(UUID.randomUUID().toString)
-           .buildForBatch()
 
-          doWrite(batchWrite)
+          writeBuilder match {
+            case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
+            case v2 => writeWithV2(v2.buildForBatch())
+          }
 
         case _ =>
           // table does not support writes
@@ -207,6 +214,7 @@ case class AtomicReplaceTableAsSelectExec(
     catalog: StagingTableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
+    plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
@@ -242,8 +250,7 @@ case class AppendDataExec(
     query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val batchWrite = newWriteBuilder().buildForBatch()
-    doWrite(batchWrite)
+    writeWithV2(newWriteBuilder().buildForBatch())
   }
 }
 
@@ -268,18 +275,16 @@ case class OverwriteByExpressionExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val batchWrite = newWriteBuilder() match {
+    newWriteBuilder() match {
       case builder: SupportsTruncate if isTruncate(deleteWhere) =>
-        builder.truncate().buildForBatch()
+        writeWithV2(builder.truncate().buildForBatch())
 
       case builder: SupportsOverwrite =>
-        builder.overwrite(deleteWhere).buildForBatch()
+        writeWithV2(builder.overwrite(deleteWhere).buildForBatch())
 
       case _ =>
         throw new SparkException(s"Table does not support overwrite by expression: $table")
     }
-
-    doWrite(batchWrite)
   }
 }
 
@@ -298,15 +303,13 @@ case class OverwritePartitionsDynamicExec(
     query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val batchWrite = newWriteBuilder() match {
+    newWriteBuilder() match {
       case builder: SupportsDynamicOverwrite =>
-        builder.overwriteDynamicPartitions().buildForBatch()
+        writeWithV2(builder.overwriteDynamicPartitions().buildForBatch())
 
       case _ =>
         throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
     }
-
-    doWrite(batchWrite)
   }
 }
 
@@ -317,7 +320,7 @@ case class WriteToDataSourceV2Exec(
   def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()
 
   override protected def doExecute(): RDD[InternalRow] = {
-    doWrite(batchWrite)
+    writeWithV2(batchWrite)
   }
 }
 
@@ -331,8 +334,8 @@ trait BatchWriteHelper {
 
   def newWriteBuilder(): WriteBuilder = {
     table.newWriteBuilder(writeOptions)
       .withInputDataSchema(query.schema)
       .withQueryId(UUID.randomUUID().toString)
   }
 }
 
@@ -347,7 +350,7 @@ trait V2TableWriteExec extends UnaryExecNode {
   override def child: SparkPlan = query
   override def output: Seq[Attribute] = Nil
 
-  protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = {
+  protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
     val writerFactory = batchWrite.createBatchWriterFactory()
     val useCommitCoordinator = batchWrite.useCommitCoordinator
     val rdd = query.execute()
@@ -463,7 +466,7 @@ object DataWritingSparkTask extends Logging {
   }
 }
 
-private[v2] trait AtomicTableWriteExec extends V2TableWriteExec {
+private[v2] trait AtomicTableWriteExec extends V2TableWriteExec with SupportsV1Write {
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper
 
   protected def writeToStagedTable(
@@ -473,14 +476,17 @@ private[v2] trait AtomicTableWriteExec extends V2TableWriteExec {
     Utils.tryWithSafeFinallyAndFailureCallbacks({
       stagedTable match {
         case table: SupportsWrite =>
-          val batchWrite = table.newWriteBuilder(writeOptions)
+          val writeBuilder = table.newWriteBuilder(writeOptions)
            .withInputDataSchema(query.schema)
            .withQueryId(UUID.randomUUID().toString)
-           .buildForBatch()
 
-          val writtenRows = doWrite(batchWrite)
+          val writtenRows = writeBuilder match {
+            case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
+            case v2 => writeWithV2(v2.buildForBatch())
+          }
           stagedTable.commitStagedChanges()
           writtenRows
 
         case _ =>
           // Table does not support writes - staged changes are also rolled back below.
           throw new SparkException(
```
**V1WriteBuilder.scala** (new file)

```diff
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.writer
+
+import org.apache.spark.annotation.{Experimental, Unstable}
+import org.apache.spark.sql.sources.InsertableRelation
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
+
+/**
+ * A trait that should be implemented by V1 DataSources that would like to leverage the DataSource
+ * V2 write code paths. The InsertableRelation will be used only to Append data. Other
+ * instances of the [[WriteBuilder]] interface such as [[SupportsOverwrite]] and [[SupportsTruncate]]
+ * should be extended as well to support additional operations other than data appends.
+ *
+ * This interface is designed to provide Spark DataSources time to migrate to DataSource V2 and
+ * will be removed in a future Spark release.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+@Unstable
+trait V1WriteBuilder extends WriteBuilder {
+
+  /**
+   * Creates an InsertableRelation that allows appending a DataFrame to a
+   * destination (using data source-specific parameters). The insert method will only be
+   * called with `overwrite=false`. The DataSource should implement the overwrite behavior as
+   * part of the [[SupportsOverwrite]], and [[SupportsTruncate]] interfaces.
+   *
+   * @since 3.0.0
+   */
+  def buildForV1Write(): InsertableRelation
+
+  // These methods cannot be implemented by a V1WriteBuilder. The super class will throw
+  // an UnsupportedOperationException
+  override final def buildForBatch(): BatchWrite = super.buildForBatch()
+
+  override final def buildForStreaming(): StreamingWrite = super.buildForStreaming()
+}
```
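For illustration, a minimal in-memory implementation of this contract (the `BufferedV1WriteBuilder` name and its `ArrayBuffer` backing store are made up for this sketch; a fuller version appears in the test suite below):

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.sources.v2.writer.{SupportsTruncate, V1WriteBuilder, WriteBuilder}

// Hypothetical builder backed by an in-memory buffer. truncate() clears the
// buffer eagerly, so SaveMode.Overwrite becomes truncate + append; the
// returned relation is only ever asked to append (overwrite is always false).
class BufferedV1WriteBuilder(buffer: ArrayBuffer[Row])
  extends V1WriteBuilder with SupportsTruncate {

  override def truncate(): WriteBuilder = {
    buffer.clear()
    this
  }

  override def buildForV1Write(): InsertableRelation = new InsertableRelation {
    override def insert(data: DataFrame, overwrite: Boolean): Unit = {
      assert(!overwrite, "the fallback always appends")
      buffer ++= data.collect()
    }
  }
}
```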
**DataSourceV2DataFrameSessionCatalogSuite.scala**

```diff
@@ -31,39 +31,90 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 class DataSourceV2DataFrameSessionCatalogSuite
+  extends SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] {
+
+  test("saveAsTable: Append mode should not fail if the table already exists " +
+    "and a same-name temp view exist") {
+    withTable("same_name") {
+      withTempView("same_name") {
+        val format = spark.sessionState.conf.defaultDataSourceName
+        sql(s"CREATE TABLE same_name(id LONG) USING $format")
+        spark.range(10).createTempView("same_name")
+        spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name")
+        checkAnswer(spark.table("same_name"), spark.range(10).toDF())
+        checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
+      }
+    }
+  }
+
+  test("saveAsTable with mode Overwrite should not fail if the table already exists " +
+    "and a same-name temp view exist") {
+    withTable("same_name") {
+      withTempView("same_name") {
+        sql(s"CREATE TABLE same_name(id LONG) USING $v2Format")
+        spark.range(10).createTempView("same_name")
+        spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name")
+        checkAnswer(spark.table("same_name"), spark.range(10).toDF())
+        checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
+      }
+    }
+  }
+}
+
+class InMemoryTableProvider extends TableProvider {
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    throw new UnsupportedOperationException("D'oh!")
+  }
+}
+
+class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] {
+  override def newTable(
+      name: String,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): InMemoryTable = {
+    new InMemoryTable(name, schema, partitions, properties)
+  }
+}
+
+private[v2] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]]
   extends QueryTest
   with SharedSparkSession
   with BeforeAndAfter {
-  import testImplicits._
 
-  private def catalog(name: String): CatalogPlugin = {
+  protected def catalog(name: String): CatalogPlugin = {
     spark.sessionState.catalogManager.catalog(name)
   }
 
-  private val v2Format = classOf[InMemoryTableProvider].getName
+  protected val v2Format = classOf[InMemoryTableProvider].getName
+
+  protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName
 
   before {
-    spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[TestV2SessionCatalog].getName)
+    spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, catalogClassName)
   }
 
   override def afterEach(): Unit = {
     super.afterEach()
-    catalog("session").asInstanceOf[TestV2SessionCatalog].clearTables()
+    catalog("session").asInstanceOf[Catalog].clearTables()
     spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName)
   }
 
-  private def verifyTable(tableName: String, expected: DataFrame): Unit = {
+  protected def verifyTable(tableName: String, expected: DataFrame): Unit = {
     checkAnswer(spark.table(tableName), expected)
     checkAnswer(sql(s"SELECT * FROM $tableName"), expected)
     checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected)
     checkAnswer(sql(s"TABLE $tableName"), expected)
   }
 
+  import testImplicits._
+
   test("saveAsTable: v2 table - table doesn't exist and default mode (ErrorIfExists)") {
     val t1 = "tbl"
     val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
@@ -90,20 +141,6 @@ class DataSourceV2DataFrameSessionCatalogSuite
     }
   }
 
-  test("saveAsTable: Append mode should not fail if the table already exists " +
-    "and a same-name temp view exist") {
-    withTable("same_name") {
-      withTempView("same_name") {
-        val format = spark.sessionState.conf.defaultDataSourceName
-        sql(s"CREATE TABLE same_name(id LONG) USING $format")
-        spark.range(10).createTempView("same_name")
-        spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name")
-        checkAnswer(spark.table("same_name"), spark.range(10).toDF())
-        checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
-      }
-    }
-  }
-
   test("saveAsTable: v2 table - table exists") {
     val t1 = "tbl"
     val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
@@ -147,19 +184,6 @@ class DataSourceV2DataFrameSessionCatalogSuite
     }
   }
 
-  test("saveAsTable with mode Overwrite should not fail if the table already exists " +
-    "and a same-name temp view exist") {
-    withTable("same_name") {
-      withTempView("same_name") {
-        sql(s"CREATE TABLE same_name(id LONG) USING $v2Format")
-        spark.range(10).createTempView("same_name")
-        spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name")
-        checkAnswer(spark.table("same_name"), spark.range(10).toDF())
-        checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
-      }
-    }
-  }
-
   test("saveAsTable: v2 table - ignore mode and table doesn't exist") {
     val t1 = "tbl"
     val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
@@ -175,55 +199,3 @@ class DataSourceV2DataFrameSessionCatalogSuite
     verifyTable(t1, Seq(("c", "d")).toDF("id", "data"))
   }
 }
-
-class InMemoryTableProvider extends TableProvider {
-  override def getTable(options: CaseInsensitiveStringMap): Table = {
-    throw new UnsupportedOperationException("D'oh!")
-  }
-}
-
-/** A SessionCatalog that always loads an in memory Table, so we can test write code paths. */
-class TestV2SessionCatalog extends V2SessionCatalog {
-
-  protected val tables: util.Map[Identifier, InMemoryTable] =
-    new ConcurrentHashMap[Identifier, InMemoryTable]()
-
-  private def fullIdentifier(ident: Identifier): Identifier = {
-    if (ident.namespace().isEmpty) {
-      Identifier.of(Array("default"), ident.name())
-    } else {
-      ident
-    }
-  }
-
-  override def loadTable(ident: Identifier): Table = {
-    val fullIdent = fullIdentifier(ident)
-    if (tables.containsKey(fullIdent)) {
-      tables.get(fullIdent)
-    } else {
-      // Table was created through the built-in catalog
-      val t = super.loadTable(fullIdent)
-      val table = new InMemoryTable(t.name(), t.schema(), t.partitioning(), t.properties())
-      tables.put(fullIdent, table)
-      table
-    }
-  }
-
-  override def createTable(
-      ident: Identifier,
-      schema: StructType,
-      partitions: Array[Transform],
-      properties: util.Map[String, String]): Table = {
-    val created = super.createTable(ident, schema, partitions, properties)
-    val t = new InMemoryTable(created.name(), schema, partitions, properties)
-    val fullIdent = fullIdentifier(ident)
-    tables.put(fullIdent, t)
-    t
-  }
-
-  def clearTables(): Unit = {
-    assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?")
-    tables.keySet().asScala.foreach(super.dropTable)
-    tables.clear()
-  }
-}
```
**TestInMemoryTableCatalog.scala**

```diff
@@ -234,7 +234,8 @@ class InMemoryTable(
 
   private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
-      dataMap --= deletesKeys(filters)
+      val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+      dataMap --= deleteKeys
       withData(messages.map(_.asInstanceOf[BufferedRows]))
     }
   }
@@ -247,7 +248,37 @@ class InMemoryTable(
   }
 
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
-    dataMap --= deletesKeys(filters)
+    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+  }
+}
+
+object InMemoryTable {
+  def filtersToKeys(
+      keys: Iterable[Seq[Any]],
+      partitionNames: Seq[String],
+      filters: Array[Filter]): Iterable[Seq[Any]] = {
+    keys.filter { partValues =>
+      filters.flatMap(splitAnd).forall {
+        case EqualTo(attr, value) =>
+          value == extractValue(attr, partitionNames, partValues)
+        case IsNotNull(attr) =>
+          null != extractValue(attr, partitionNames, partValues)
+        case f =>
+          throw new IllegalArgumentException(s"Unsupported filter type: $f")
+      }
+    }
+  }
+
+  private def extractValue(
+      attr: String,
+      partFieldNames: Seq[String],
+      partValues: Seq[Any]): Any = {
+    partFieldNames.zipWithIndex.find(_._1 == attr) match {
+      case Some((_, partIndex)) =>
+        partValues(partIndex)
+      case _ =>
+        throw new IllegalArgumentException(s"Unknown filter attribute: $attr")
+    }
   }
 
   private def splitAnd(filter: Filter): Seq[Filter] = {
@@ -256,30 +287,6 @@ class InMemoryTable(
     case _ => filter :: Nil
     }
   }
-
-  private def deletesKeys(filters: Array[Filter]): Iterable[Seq[Any]] = {
-    dataMap.synchronized {
-      dataMap.keys.filter { partValues =>
-        filters.flatMap(splitAnd).forall {
-          case EqualTo(attr, value) =>
-            value == extractValue(attr, partValues)
-          case IsNotNull(attr) =>
-            null != extractValue(attr, partValues)
-          case f =>
-            throw new IllegalArgumentException(s"Unsupported filter type: $f")
-        }
-      }
-    }
-  }
-
-  private def extractValue(attr: String, partValues: Seq[Any]): Any = {
-    partFieldNames.zipWithIndex.find(_._1 == attr) match {
-      case Some((_, partIndex)) =>
-        partValues(partIndex)
-      case _ =>
-        throw new IllegalArgumentException(s"Unknown filter attribute: $attr")
-    }
-  }
 }
 
 object TestInMemoryTableCatalog {
```
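A quick worked example of the extracted helper (values made up): with a single partition column `a`, an `EqualTo("a", 1)` filter keeps only the keys whose `a` value is 1.

```scala
import org.apache.spark.sql.sources.EqualTo

// Hypothetical data: two partitions keyed by column "a".
val keys = Seq(Seq[Any](1), Seq[Any](2))
// Returns Iterable(Seq(1)): the key Seq(2) fails EqualTo("a", 1).
val matched = InMemoryTable.filtersToKeys(keys, Seq("a"), Array(EqualTo("a", 1)))
```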
**V1WriteFallbackSuite.scala** (new file)

```diff
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform, Transform}
+import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation}
+import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase
+import org.apache.spark.sql.sources.v2.writer.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
+
+  import testImplicits._
+
+  private val v2Format = classOf[InMemoryV1Provider].getName
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    InMemoryV1Provider.clear()
+  }
+
+  override def afterEach(): Unit = {
+    super.afterEach()
+    InMemoryV1Provider.clear()
+  }
+
+  test("append fallback") {
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+    df.write.mode("append").option("name", "t1").format(v2Format).save()
+    checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
+    df.write.mode("append").option("name", "t1").format(v2Format).save()
+    checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df.union(df))
+  }
+
+  test("overwrite by truncate fallback") {
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+    df.write.mode("append").option("name", "t1").format(v2Format).save()
+
+    val df2 = Seq((10, "k"), (20, "l"), (30, "m")).toDF("a", "b")
+    df2.write.mode("overwrite").option("name", "t1").format(v2Format).save()
+    checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df2)
+  }
+}
+
+class V1WriteFallbackSessionCatalogSuite
+  extends SessionCatalogTest[InMemoryTableWithV1Fallback, V1FallbackTableCatalog] {
+  override protected val v2Format = classOf[InMemoryV1Provider].getName
+  override protected val catalogClassName: String = classOf[V1FallbackTableCatalog].getName
+
+  override protected def verifyTable(tableName: String, expected: DataFrame): Unit = {
+    checkAnswer(InMemoryV1Provider.getTableData(spark, s"default.$tableName"), expected)
+  }
+}
+
+class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV1Fallback] {
+  override def newTable(
+      name: String,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): InMemoryTableWithV1Fallback = {
+    val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties)
+    InMemoryV1Provider.tables.put(name, t)
+    t
+  }
+}
+
+private object InMemoryV1Provider {
+  val tables: mutable.Map[String, InMemoryTableWithV1Fallback] = mutable.Map.empty
+
+  def getTableData(spark: SparkSession, name: String): DataFrame = {
+    val t = tables.getOrElse(name, throw new IllegalArgumentException(s"Table $name doesn't exist"))
+    spark.createDataFrame(t.getData.asJava, t.schema)
+  }
+
+  def clear(): Unit = {
+    tables.clear()
+  }
+}
+
+class InMemoryV1Provider extends TableProvider with DataSourceRegister {
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), {
+      new InMemoryTableWithV1Fallback(
+        "InMemoryTableWithV1Fallback",
+        new StructType().add("a", IntegerType).add("b", StringType),
+        Array(IdentityTransform(FieldReference(Seq("a")))),
+        options.asCaseSensitiveMap()
+      )
+    })
+  }
+
+  override def shortName(): String = "in-memory"
+}
+
+class InMemoryTableWithV1Fallback(
+    override val name: String,
+    override val schema: StructType,
+    override val partitioning: Array[Transform],
+    override val properties: util.Map[String, String]) extends Table with SupportsWrite {
+
+  partitioning.foreach { t =>
+    if (!t.isInstanceOf[IdentityTransform]) {
+      throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
+    }
+  }
+
+  override def capabilities: util.Set[TableCapability] = Set(
+    TableCapability.BATCH_WRITE,
+    TableCapability.V1_BATCH_WRITE,
+    TableCapability.OVERWRITE_BY_FILTER,
+    TableCapability.TRUNCATE).asJava
+
+  @volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty
+  private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames)
+  private val partIndexes = partFieldNames.map(schema.fieldIndex(_))
+
+  def getData: Seq[Row] = dataMap.values.flatten.toSeq
+
+  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
+    new FallbackWriteBuilder(options)
+  }
+
+  private class FallbackWriteBuilder(options: CaseInsensitiveStringMap)
+    extends WriteBuilder
+    with V1WriteBuilder
+    with SupportsTruncate
+    with SupportsOverwrite {
+
+    private var mode = "append"
+
+    override def truncate(): WriteBuilder = {
+      dataMap.clear()
+      mode = "truncate"
+      this
+    }
+
+    override def overwrite(filters: Array[Filter]): WriteBuilder = {
+      val keys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+      dataMap --= keys
+      mode = "overwrite"
+      this
+    }
+
+    private def getPartitionValues(row: Row): Seq[Any] = {
+      partIndexes.map(row.get)
+    }
+
+    override def buildForV1Write(): InsertableRelation = {
+      new InsertableRelation {
+        override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+          assert(!overwrite, "V1 write fallbacks cannot be called with overwrite=true")
+          val rows = data.collect()
+          rows.groupBy(getPartitionValues).foreach { case (partition, elements) =>
+            if (dataMap.contains(partition) && mode == "append") {
+              dataMap.put(partition, dataMap(partition) ++ elements)
+            } else if (dataMap.contains(partition)) {
+              throw new IllegalStateException("Partition was not removed properly")
+            } else {
+              dataMap.put(partition, elements)
+            }
+          }
+        }
+      }
+    }
+  }
+}
```
**TestV2SessionCatalogBase.scala** (new file)

```diff
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.utils
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalog.v2.Identifier
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
+import org.apache.spark.sql.sources.v2.Table
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A V2SessionCatalog implementation that can be extended to generate arbitrary `Table` definitions
+ * for testing DDL as well as write operations (through df.write.saveAsTable, df.write.insertInto
+ * and SQL).
+ */
+private[v2] trait TestV2SessionCatalogBase[T <: Table] extends V2SessionCatalog {
+
+  protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
+
+  protected def newTable(
+      name: String,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): T
+
+  private def fullIdentifier(ident: Identifier): Identifier = {
+    if (ident.namespace().isEmpty) {
+      Identifier.of(Array("default"), ident.name())
+    } else {
+      ident
+    }
+  }
+
+  override def loadTable(ident: Identifier): Table = {
+    val fullIdent = fullIdentifier(ident)
+    if (tables.containsKey(fullIdent)) {
+      tables.get(fullIdent)
+    } else {
+      // Table was created through the built-in catalog
+      val t = super.loadTable(fullIdent)
+      val table = newTable(t.name(), t.schema(), t.partitioning(), t.properties())
+      tables.put(fullIdent, table)
+      table
+    }
+  }
+
+  override def createTable(
+      ident: Identifier,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): Table = {
+    val created = super.createTable(ident, schema, partitions, properties)
+    val t = newTable(created.name(), schema, partitions, properties)
+    val fullIdent = fullIdentifier(ident)
+    tables.put(fullIdent, t)
+    t
+  }
+
+  def clearTables(): Unit = {
+    assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?")
+    tables.keySet().asScala.foreach(super.dropTable)
+    tables.clear()
+  }
+}
```