[SPARK-28554][SQL] Adds a v1 fallback writer implementation for v2 data source codepaths

## What changes were proposed in this pull request?

This PR adds a V1 fallback interface for writing to V2 tables through the V1 writer interfaces. The only `SaveMode` that will be passed to the target table is `Append`. The target table must implement V2 interfaces such as `SupportsOverwrite` or `SupportsTruncate` to support overwrite operations, and it is up to the target DataSource implementation whether the operation is atomic.

We do not support dynamicPartitionOverwrite, as this fallback has no `commit` method that could clean up only the data in the partitions that were touched.
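
As a rough sketch (the class name and method bodies below are illustrative, not part of this PR), a source opting into the fallback declares the `V1_BATCH_WRITE` capability and returns a `V1WriteBuilder` whose `InsertableRelation` performs the append:

```scala
import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.sources.v2.writer.{V1WriteBuilder, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical example table; only the interfaces it uses come from this PR.
class ExampleV1FallbackTable(override val schema: StructType) extends Table with SupportsWrite {
  override def name(): String = "example"

  // V1_BATCH_WRITE routes appends through buildForV1Write() instead of buildForBatch().
  override def capabilities(): util.Set[TableCapability] =
    Set(TableCapability.BATCH_WRITE, TableCapability.V1_BATCH_WRITE).asJava

  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
    new V1WriteBuilder {
      override def buildForV1Write(): InsertableRelation = new InsertableRelation {
        // The fallback only ever calls this with overwrite = false.
        override def insert(data: DataFrame, overwrite: Boolean): Unit = {
          // hand `data` off to the external system here
        }
      }
    }
}
```

With such a table exposed through a `TableProvider`, an append via `df.write.format(...).mode("append").save()` is planned as `AppendDataExecV1` rather than the V2 batch write path.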

## How was this patch tested?

Tests and an example implementation will be added after comments and feedback. This is a proposal at this point.

Closes #25348 from brkyvz/v1WriteFallback.

Lead-authored-by: Burak Yavuz <brkyvz@gmail.com>
Co-authored-by: Burak Yavuz <burak@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Authored by Burak Yavuz on 2019-08-21 17:25:25 +08:00, committed by Wenchen Fan
parent c4257b18a1
commit 4855bfe16b
9 changed files with 604 additions and 146 deletions


@ -89,5 +89,14 @@ public enum TableCapability {
/**
* Signals that the table accepts input of any schema in a write operation.
*/
ACCEPT_ANY_SCHEMA
ACCEPT_ANY_SCHEMA,
/**
* Signals that the table supports append writes using the V1 InsertableRelation interface.
* <p>
* Tables that return this capability must create a V1WriteBuilder and may also support additional
* write modes, like {@link #TRUNCATE}, and {@link #OVERWRITE_BY_FILTER}, but cannot support
* {@link #OVERWRITE_DYNAMIC}.
*/
V1_BATCH_WRITE
}


@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.v2
import java.util.UUID
import scala.collection.JavaConverters._
import scala.collection.mutable
@ -29,8 +31,10 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.sources
import org.apache.spark.sql.sources.v2.TableCapability
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.sources.v2.writer.V1WriteBuilder
import org.apache.spark.sql.util.CaseInsensitiveStringMap
object DataSourceV2Strategy extends Strategy with PredicateHelper {
@ -169,10 +173,10 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
catalog match {
case staging: StagingTableCatalog =>
AtomicCreateTableAsSelectExec(
staging, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil
staging, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil
case _ =>
CreateTableAsSelectExec(
catalog, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil
catalog, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil
}
case ReplaceTable(catalog, ident, schema, parts, props, orCreate) =>
@ -191,6 +195,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
staging,
ident,
parts,
query,
planLater(query),
props,
writeOptions,
@ -200,6 +205,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
catalog,
ident,
parts,
query,
planLater(query),
props,
writeOptions,
@ -207,7 +213,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
}
case AppendData(r: DataSourceV2Relation, query, _) =>
AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil
r.table.asWritable match {
case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
AppendDataExecV1(v1, r.options, query) :: Nil
case v2 =>
AppendDataExec(v2, r.options, planLater(query)) :: Nil
}
case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) =>
// fail if any filter cannot be converted. correctness depends on removing all matching data.
@ -215,9 +226,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse(
throw new AnalysisException(s"Cannot translate expression to source filter: $filter"))
}.toArray
OverwriteByExpressionExec(
r.table.asWritable, filters, r.options, planLater(query)) :: Nil
r.table.asWritable match {
case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
OverwriteByExpressionExecV1(v1, filters, r.options, query) :: Nil
case v2 =>
OverwriteByExpressionExec(v2, filters, r.options, planLater(query)) :: Nil
}
case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) =>
OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil


@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2
import java.util.UUID
import scala.collection.JavaConverters._
import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SaveMode}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sources.{AlwaysTrue, CreatableRelationProvider, Filter, InsertableRelation}
import org.apache.spark.sql.sources.v2.{SupportsWrite, Table}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* Physical plan node for append into a v2 table using V1 write interfaces.
*
* Rows in the output data set are appended.
*/
case class AppendDataExecV1(
table: SupportsWrite,
writeOptions: CaseInsensitiveStringMap,
plan: LogicalPlan) extends V1FallbackWriters {
override protected def doExecute(): RDD[InternalRow] = {
writeWithV1(newWriteBuilder().buildForV1Write())
}
}
/**
* Physical plan node for overwrite into a v2 table with V1 write interfaces. Note that when this
* interface is used, the atomicity of the operation depends solely on the target data source.
*
* Overwrites data in a table matched by a set of filters. Rows matching all of the filters will be
* deleted and rows in the output data set are appended.
*
* This plan is used to implement SaveMode.Overwrite. The behavior of SaveMode.Overwrite is to
* truncate the table -- delete all rows -- and append the output data set. This uses the filter
* AlwaysTrue to delete all rows.
*/
case class OverwriteByExpressionExecV1(
table: SupportsWrite,
deleteWhere: Array[Filter],
writeOptions: CaseInsensitiveStringMap,
plan: LogicalPlan) extends V1FallbackWriters {
private def isTruncate(filters: Array[Filter]): Boolean = {
filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
}
override protected def doExecute(): RDD[InternalRow] = {
newWriteBuilder() match {
case builder: SupportsTruncate if isTruncate(deleteWhere) =>
writeWithV1(builder.truncate().asV1Builder.buildForV1Write())
case builder: SupportsOverwrite =>
writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write())
case _ =>
throw new SparkException(s"Table does not support overwrite by expression: $table")
}
}
}
/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
sealed trait V1FallbackWriters extends SupportsV1Write {
override def output: Seq[Attribute] = Nil
override final def children: Seq[SparkPlan] = Nil
def table: SupportsWrite
def writeOptions: CaseInsensitiveStringMap
protected implicit class toV1WriteBuilder(builder: WriteBuilder) {
def asV1Builder: V1WriteBuilder = builder match {
case v1: V1WriteBuilder => v1
case other => throw new IllegalStateException(
s"The returned writer ${other} was no longer a V1WriteBuilder.")
}
}
protected def newWriteBuilder(): V1WriteBuilder = {
val writeBuilder = table.newWriteBuilder(writeOptions)
.withInputDataSchema(plan.schema)
.withQueryId(UUID.randomUUID().toString)
writeBuilder.asV1Builder
}
}
/**
* A trait that allows Tables that use V1 Writer interfaces to append data.
*/
trait SupportsV1Write extends SparkPlan {
// TODO: We should be able to work on SparkPlans at this point.
def plan: LogicalPlan
protected def writeWithV1(relation: InsertableRelation): RDD[InternalRow] = {
relation.insert(Dataset.ofRows(sqlContext.sparkSession, plan), overwrite = false)
sparkContext.emptyRDD
}
}


@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
import org.apache.spark.sql.sources.v2.{StagedTable, SupportsWrite}
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.{LongAccumulator, Utils}
@ -63,10 +63,11 @@ case class CreateTableAsSelectExec(
catalog: TableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
plan: LogicalPlan,
query: SparkPlan,
properties: Map[String, String],
writeOptions: CaseInsensitiveStringMap,
ifNotExists: Boolean) extends V2TableWriteExec {
ifNotExists: Boolean) extends V2TableWriteExec with SupportsV1Write {
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper
@ -83,12 +84,14 @@ case class CreateTableAsSelectExec(
catalog.createTable(
ident, query.schema, partitioning.toArray, properties.asJava) match {
case table: SupportsWrite =>
val batchWrite = table.newWriteBuilder(writeOptions)
val writeBuilder = table.newWriteBuilder(writeOptions)
.withInputDataSchema(query.schema)
.withQueryId(UUID.randomUUID().toString)
.buildForBatch()
doWrite(batchWrite)
writeBuilder match {
case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
case v2 => writeWithV2(v2.buildForBatch())
}
case _ =>
// table does not support writes
@ -114,6 +117,7 @@ case class AtomicCreateTableAsSelectExec(
catalog: StagingTableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
plan: LogicalPlan,
query: SparkPlan,
properties: Map[String, String],
writeOptions: CaseInsensitiveStringMap,
@ -147,10 +151,11 @@ case class ReplaceTableAsSelectExec(
catalog: TableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
plan: LogicalPlan,
query: SparkPlan,
properties: Map[String, String],
writeOptions: CaseInsensitiveStringMap,
orCreate: Boolean) extends AtomicTableWriteExec {
orCreate: Boolean) extends V2TableWriteExec with SupportsV1Write {
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper
@ -173,12 +178,14 @@ case class ReplaceTableAsSelectExec(
Utils.tryWithSafeFinallyAndFailureCallbacks({
createdTable match {
case table: SupportsWrite =>
val batchWrite = table.newWriteBuilder(writeOptions)
val writeBuilder = table.newWriteBuilder(writeOptions)
.withInputDataSchema(query.schema)
.withQueryId(UUID.randomUUID().toString)
.buildForBatch()
doWrite(batchWrite)
writeBuilder match {
case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
case v2 => writeWithV2(v2.buildForBatch())
}
case _ =>
// table does not support writes
@ -207,6 +214,7 @@ case class AtomicReplaceTableAsSelectExec(
catalog: StagingTableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
plan: LogicalPlan,
query: SparkPlan,
properties: Map[String, String],
writeOptions: CaseInsensitiveStringMap,
@ -242,8 +250,7 @@ case class AppendDataExec(
query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
override protected def doExecute(): RDD[InternalRow] = {
val batchWrite = newWriteBuilder().buildForBatch()
doWrite(batchWrite)
writeWithV2(newWriteBuilder().buildForBatch())
}
}
@ -268,18 +275,16 @@ case class OverwriteByExpressionExec(
}
override protected def doExecute(): RDD[InternalRow] = {
val batchWrite = newWriteBuilder() match {
newWriteBuilder() match {
case builder: SupportsTruncate if isTruncate(deleteWhere) =>
builder.truncate().buildForBatch()
writeWithV2(builder.truncate().buildForBatch())
case builder: SupportsOverwrite =>
builder.overwrite(deleteWhere).buildForBatch()
writeWithV2(builder.overwrite(deleteWhere).buildForBatch())
case _ =>
throw new SparkException(s"Table does not support overwrite by expression: $table")
}
doWrite(batchWrite)
}
}
@ -298,15 +303,13 @@ case class OverwritePartitionsDynamicExec(
query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
override protected def doExecute(): RDD[InternalRow] = {
val batchWrite = newWriteBuilder() match {
newWriteBuilder() match {
case builder: SupportsDynamicOverwrite =>
builder.overwriteDynamicPartitions().buildForBatch()
writeWithV2(builder.overwriteDynamicPartitions().buildForBatch())
case _ =>
throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
}
doWrite(batchWrite)
}
}
@ -317,7 +320,7 @@ case class WriteToDataSourceV2Exec(
def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()
override protected def doExecute(): RDD[InternalRow] = {
doWrite(batchWrite)
writeWithV2(batchWrite)
}
}
@ -331,8 +334,8 @@ trait BatchWriteHelper {
def newWriteBuilder(): WriteBuilder = {
table.newWriteBuilder(writeOptions)
.withInputDataSchema(query.schema)
.withQueryId(UUID.randomUUID().toString)
.withInputDataSchema(query.schema)
.withQueryId(UUID.randomUUID().toString)
}
}
@ -347,7 +350,7 @@ trait V2TableWriteExec extends UnaryExecNode {
override def child: SparkPlan = query
override def output: Seq[Attribute] = Nil
protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = {
protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
val writerFactory = batchWrite.createBatchWriterFactory()
val useCommitCoordinator = batchWrite.useCommitCoordinator
val rdd = query.execute()
@ -463,7 +466,7 @@ object DataWritingSparkTask extends Logging {
}
}
private[v2] trait AtomicTableWriteExec extends V2TableWriteExec {
private[v2] trait AtomicTableWriteExec extends V2TableWriteExec with SupportsV1Write {
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper
protected def writeToStagedTable(
@ -473,14 +476,17 @@ private[v2] trait AtomicTableWriteExec extends V2TableWriteExec {
Utils.tryWithSafeFinallyAndFailureCallbacks({
stagedTable match {
case table: SupportsWrite =>
val batchWrite = table.newWriteBuilder(writeOptions)
val writeBuilder = table.newWriteBuilder(writeOptions)
.withInputDataSchema(query.schema)
.withQueryId(UUID.randomUUID().toString)
.buildForBatch()
val writtenRows = doWrite(batchWrite)
val writtenRows = writeBuilder match {
case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
case v2 => writeWithV2(v2.buildForBatch())
}
stagedTable.commitStagedChanges()
writtenRows
case _ =>
// Table does not support writes - staged changes are also rolled back below.
throw new SparkException(


@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.sources.v2.writer
import org.apache.spark.annotation.{Experimental, Unstable}
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
/**
* A trait that should be implemented by V1 DataSources that would like to leverage the DataSource
* V2 write code paths. The InsertableRelation will be used only to Append data. Other
* instances of the [[WriteBuilder]] interface such as [[SupportsOverwrite]], [[SupportsTruncate]]
* should be extended as well to support additional operations other than data appends.
*
* This interface is designed to provide Spark DataSources time to migrate to DataSource V2 and
* will be removed in a future Spark release.
*
* @since 3.0.0
*/
@Experimental
@Unstable
trait V1WriteBuilder extends WriteBuilder {
/**
* Creates an InsertableRelation that allows appending a DataFrame to a
* destination (using data source-specific parameters). The insert method will only be
* called with `overwrite=false`. The DataSource should implement the overwrite behavior as
* part of the [[SupportsOverwrite]], and [[SupportsTruncate]] interfaces.
*
* @since 3.0.0
*/
def buildForV1Write(): InsertableRelation
// These methods cannot be implemented by a V1WriteBuilder. The super class will throw
// an UnsupportedOperationException.
override final def buildForBatch(): BatchWrite = super.buildForBatch()
override final def buildForStreaming(): StreamingWrite = super.buildForStreaming()
}
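
For example (a hedged sketch under the same assumptions as above, not part of this diff), a builder that should also work with `SaveMode.Overwrite` through the fallback would mix in `SupportsTruncate` alongside `V1WriteBuilder`, much like the `FallbackWriteBuilder` in the test suite further down:

```scala
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.sources.v2.writer.{SupportsTruncate, V1WriteBuilder, WriteBuilder}

// Hypothetical builder: truncate() is invoked by OverwriteByExpressionExecV1 when the
// delete filter is AlwaysTrue, i.e. for a plain SaveMode.Overwrite.
class TruncatingV1WriteBuilder(doTruncate: () => Unit, doInsert: InsertableRelation)
  extends V1WriteBuilder with SupportsTruncate {

  override def truncate(): WriteBuilder = {
    doTruncate() // drop existing data before the subsequent append
    this
  }

  override def buildForV1Write(): InsertableRelation = doInsert
}
```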


@ -31,39 +31,90 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class DataSourceV2DataFrameSessionCatalogSuite
extends SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] {
test("saveAsTable: Append mode should not fail if the table already exists " +
"and a same-name temp view exist") {
withTable("same_name") {
withTempView("same_name") {
val format = spark.sessionState.conf.defaultDataSourceName
sql(s"CREATE TABLE same_name(id LONG) USING $format")
spark.range(10).createTempView("same_name")
spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name")
checkAnswer(spark.table("same_name"), spark.range(10).toDF())
checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
}
}
}
test("saveAsTable with mode Overwrite should not fail if the table already exists " +
"and a same-name temp view exist") {
withTable("same_name") {
withTempView("same_name") {
sql(s"CREATE TABLE same_name(id LONG) USING $v2Format")
spark.range(10).createTempView("same_name")
spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name")
checkAnswer(spark.table("same_name"), spark.range(10).toDF())
checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
}
}
}
}
class InMemoryTableProvider extends TableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new UnsupportedOperationException("D'oh!")
}
}
class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] {
override def newTable(
name: String,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): InMemoryTable = {
new InMemoryTable(name, schema, partitions, properties)
}
}
private[v2] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]]
extends QueryTest
with SharedSparkSession
with BeforeAndAfter {
import testImplicits._
private def catalog(name: String): CatalogPlugin = {
protected def catalog(name: String): CatalogPlugin = {
spark.sessionState.catalogManager.catalog(name)
}
private val v2Format = classOf[InMemoryTableProvider].getName
protected val v2Format = classOf[InMemoryTableProvider].getName
protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName
before {
spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[TestV2SessionCatalog].getName)
spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, catalogClassName)
}
override def afterEach(): Unit = {
super.afterEach()
catalog("session").asInstanceOf[TestV2SessionCatalog].clearTables()
catalog("session").asInstanceOf[Catalog].clearTables()
spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName)
}
private def verifyTable(tableName: String, expected: DataFrame): Unit = {
protected def verifyTable(tableName: String, expected: DataFrame): Unit = {
checkAnswer(spark.table(tableName), expected)
checkAnswer(sql(s"SELECT * FROM $tableName"), expected)
checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected)
checkAnswer(sql(s"TABLE $tableName"), expected)
}
import testImplicits._
test("saveAsTable: v2 table - table doesn't exist and default mode (ErrorIfExists)") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
@ -90,20 +141,6 @@ class DataSourceV2DataFrameSessionCatalogSuite
}
}
test("saveAsTable: Append mode should not fail if the table already exists " +
"and a same-name temp view exist") {
withTable("same_name") {
withTempView("same_name") {
val format = spark.sessionState.conf.defaultDataSourceName
sql(s"CREATE TABLE same_name(id LONG) USING $format")
spark.range(10).createTempView("same_name")
spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name")
checkAnswer(spark.table("same_name"), spark.range(10).toDF())
checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
}
}
}
test("saveAsTable: v2 table - table exists") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
@ -147,19 +184,6 @@ class DataSourceV2DataFrameSessionCatalogSuite
}
}
test("saveAsTable with mode Overwrite should not fail if the table already exists " +
"and a same-name temp view exist") {
withTable("same_name") {
withTempView("same_name") {
sql(s"CREATE TABLE same_name(id LONG) USING $v2Format")
spark.range(10).createTempView("same_name")
spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name")
checkAnswer(spark.table("same_name"), spark.range(10).toDF())
checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
}
}
}
test("saveAsTable: v2 table - ignore mode and table doesn't exist") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
@ -175,55 +199,3 @@ class DataSourceV2DataFrameSessionCatalogSuite
verifyTable(t1, Seq(("c", "d")).toDF("id", "data"))
}
}
class InMemoryTableProvider extends TableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new UnsupportedOperationException("D'oh!")
}
}
/** A SessionCatalog that always loads an in memory Table, so we can test write code paths. */
class TestV2SessionCatalog extends V2SessionCatalog {
protected val tables: util.Map[Identifier, InMemoryTable] =
new ConcurrentHashMap[Identifier, InMemoryTable]()
private def fullIdentifier(ident: Identifier): Identifier = {
if (ident.namespace().isEmpty) {
Identifier.of(Array("default"), ident.name())
} else {
ident
}
}
override def loadTable(ident: Identifier): Table = {
val fullIdent = fullIdentifier(ident)
if (tables.containsKey(fullIdent)) {
tables.get(fullIdent)
} else {
// Table was created through the built-in catalog
val t = super.loadTable(fullIdent)
val table = new InMemoryTable(t.name(), t.schema(), t.partitioning(), t.properties())
tables.put(fullIdent, table)
table
}
}
override def createTable(
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
val created = super.createTable(ident, schema, partitions, properties)
val t = new InMemoryTable(created.name(), schema, partitions, properties)
val fullIdent = fullIdentifier(ident)
tables.put(fullIdent, t)
t
}
def clearTables(): Unit = {
assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?")
tables.keySet().asScala.foreach(super.dropTable)
tables.clear()
}
}


@ -234,7 +234,8 @@ class InMemoryTable(
private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
dataMap --= deletesKeys(filters)
val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
dataMap --= deleteKeys
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
@ -247,7 +248,37 @@ class InMemoryTable(
}
override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
dataMap --= deletesKeys(filters)
dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
}
}
object InMemoryTable {
def filtersToKeys(
keys: Iterable[Seq[Any]],
partitionNames: Seq[String],
filters: Array[Filter]): Iterable[Seq[Any]] = {
keys.filter { partValues =>
filters.flatMap(splitAnd).forall {
case EqualTo(attr, value) =>
value == extractValue(attr, partitionNames, partValues)
case IsNotNull(attr) =>
null != extractValue(attr, partitionNames, partValues)
case f =>
throw new IllegalArgumentException(s"Unsupported filter type: $f")
}
}
}
private def extractValue(
attr: String,
partFieldNames: Seq[String],
partValues: Seq[Any]): Any = {
partFieldNames.zipWithIndex.find(_._1 == attr) match {
case Some((_, partIndex)) =>
partValues(partIndex)
case _ =>
throw new IllegalArgumentException(s"Unknown filter attribute: $attr")
}
}
private def splitAnd(filter: Filter): Seq[Filter] = {
@ -256,30 +287,6 @@ class InMemoryTable(
case _ => filter :: Nil
}
}
private def deletesKeys(filters: Array[Filter]): Iterable[Seq[Any]] = {
dataMap.synchronized {
dataMap.keys.filter { partValues =>
filters.flatMap(splitAnd).forall {
case EqualTo(attr, value) =>
value == extractValue(attr, partValues)
case IsNotNull(attr) =>
null != extractValue(attr, partValues)
case f =>
throw new IllegalArgumentException(s"Unsupported filter type: $f")
}
}
}
}
private def extractValue(attr: String, partValues: Seq[Any]): Any = {
partFieldNames.zipWithIndex.find(_._1 == attr) match {
case Some((_, partIndex)) =>
partValues(partIndex)
case _ =>
throw new IllegalArgumentException(s"Unknown filter attribute: $attr")
}
}
}
object TestInMemoryTableCatalog {


@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.sources.v2
import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation}
import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase
import org.apache.spark.sql.sources.v2.writer.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
import testImplicits._
private val v2Format = classOf[InMemoryV1Provider].getName
override def beforeAll(): Unit = {
super.beforeAll()
InMemoryV1Provider.clear()
}
override def afterEach(): Unit = {
super.afterEach()
InMemoryV1Provider.clear()
}
test("append fallback") {
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
df.write.mode("append").option("name", "t1").format(v2Format).save()
checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
df.write.mode("append").option("name", "t1").format(v2Format).save()
checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df.union(df))
}
test("overwrite by truncate fallback") {
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
df.write.mode("append").option("name", "t1").format(v2Format).save()
val df2 = Seq((10, "k"), (20, "l"), (30, "m")).toDF("a", "b")
df2.write.mode("overwrite").option("name", "t1").format(v2Format).save()
checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df2)
}
}
class V1WriteFallbackSessionCatalogSuite
extends SessionCatalogTest[InMemoryTableWithV1Fallback, V1FallbackTableCatalog] {
override protected val v2Format = classOf[InMemoryV1Provider].getName
override protected val catalogClassName: String = classOf[V1FallbackTableCatalog].getName
override protected def verifyTable(tableName: String, expected: DataFrame): Unit = {
checkAnswer(InMemoryV1Provider.getTableData(spark, s"default.$tableName"), expected)
}
}
class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV1Fallback] {
override def newTable(
name: String,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): InMemoryTableWithV1Fallback = {
val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties)
InMemoryV1Provider.tables.put(name, t)
t
}
}
private object InMemoryV1Provider {
val tables: mutable.Map[String, InMemoryTableWithV1Fallback] = mutable.Map.empty
def getTableData(spark: SparkSession, name: String): DataFrame = {
val t = tables.getOrElse(name, throw new IllegalArgumentException(s"Table $name doesn't exist"))
spark.createDataFrame(t.getData.asJava, t.schema)
}
def clear(): Unit = {
tables.clear()
}
}
class InMemoryV1Provider extends TableProvider with DataSourceRegister {
override def getTable(options: CaseInsensitiveStringMap): Table = {
InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), {
new InMemoryTableWithV1Fallback(
"InMemoryTableWithV1Fallback",
new StructType().add("a", IntegerType).add("b", StringType),
Array(IdentityTransform(FieldReference(Seq("a")))),
options.asCaseSensitiveMap()
)
})
}
override def shortName(): String = "in-memory"
}
class InMemoryTableWithV1Fallback(
override val name: String,
override val schema: StructType,
override val partitioning: Array[Transform],
override val properties: util.Map[String, String]) extends Table with SupportsWrite {
partitioning.foreach { t =>
if (!t.isInstanceOf[IdentityTransform]) {
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
}
}
override def capabilities: util.Set[TableCapability] = Set(
TableCapability.BATCH_WRITE,
TableCapability.V1_BATCH_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
TableCapability.TRUNCATE).asJava
@volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty
private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames)
private val partIndexes = partFieldNames.map(schema.fieldIndex(_))
def getData: Seq[Row] = dataMap.values.flatten.toSeq
override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
new FallbackWriteBuilder(options)
}
private class FallbackWriteBuilder(options: CaseInsensitiveStringMap)
extends WriteBuilder
with V1WriteBuilder
with SupportsTruncate
with SupportsOverwrite {
private var mode = "append"
override def truncate(): WriteBuilder = {
dataMap.clear()
mode = "truncate"
this
}
override def overwrite(filters: Array[Filter]): WriteBuilder = {
val keys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
dataMap --= keys
mode = "overwrite"
this
}
private def getPartitionValues(row: Row): Seq[Any] = {
partIndexes.map(row.get)
}
override def buildForV1Write(): InsertableRelation = {
new InsertableRelation {
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
assert(!overwrite, "V1 write fallbacks cannot be called with overwrite=true")
val rows = data.collect()
rows.groupBy(getPartitionValues).foreach { case (partition, elements) =>
if (dataMap.contains(partition) && mode == "append") {
dataMap.put(partition, dataMap(partition) ++ elements)
} else if (dataMap.contains(partition)) {
throw new IllegalStateException("Partition was not removed properly")
} else {
dataMap.put(partition, elements)
}
}
}
}
}
}
}


@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.sources.v2.utils
import java.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
import org.apache.spark.sql.catalog.v2.Identifier
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
import org.apache.spark.sql.sources.v2.Table
import org.apache.spark.sql.types.StructType
/**
* A V2SessionCatalog implementation that can be extended to generate arbitrary `Table` definitions
* for testing DDL as well as write operations (through df.write.saveAsTable, df.write.insertInto
* and SQL).
*/
private[v2] trait TestV2SessionCatalogBase[T <: Table] extends V2SessionCatalog {
protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
protected def newTable(
name: String,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): T
private def fullIdentifier(ident: Identifier): Identifier = {
if (ident.namespace().isEmpty) {
Identifier.of(Array("default"), ident.name())
} else {
ident
}
}
override def loadTable(ident: Identifier): Table = {
val fullIdent = fullIdentifier(ident)
if (tables.containsKey(fullIdent)) {
tables.get(fullIdent)
} else {
// Table was created through the built-in catalog
val t = super.loadTable(fullIdent)
val table = newTable(t.name(), t.schema(), t.partitioning(), t.properties())
tables.put(fullIdent, table)
table
}
}
override def createTable(
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
val created = super.createTable(ident, schema, partitions, properties)
val t = newTable(created.name(), schema, partitions, properties)
val fullIdent = fullIdentifier(ident)
tables.put(fullIdent, t)
t
}
def clearTables(): Unit = {
assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?")
tables.keySet().asScala.foreach(super.dropTable)
tables.clear()
}
}