diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 051783b63e..4a3f9c4f77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,9 +24,8 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalog.v2.{CatalogManager, CatalogNotFoundException, CatalogPlugin, LookupCatalog, TableChange} +import org.apache.spark.sql.catalog.v2._ import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform} -import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util.loadTable import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, DescribeTableStatement, InsertIntoStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -641,21 +640,13 @@ class Analyzer( * [[ResolveRelations]] still resolves v1 tables. 
*/ object ResolveTables extends Rule[LogicalPlan] { - import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident)) - if catalog.isTemporaryTable(ident) => - u // temporary views take precedence over catalog table names - - case u @ UnresolvedRelation(CatalogObjectIdentifier(maybeCatalog, ident)) => - maybeCatalog.orElse(sessionCatalog) - .flatMap(loadTable(_, ident)) - .map { - case unresolved: UnresolvedTable => u - case resolved => DataSourceV2Relation.create(resolved) - } - .getOrElse(u) + case u: UnresolvedRelation => + val v2TableOpt = lookupV2Relation(u.multipartIdentifier) match { + case scala.Left((_, _, tableOpt)) => tableOpt + case scala.Right(tableOpt) => tableOpt + } + v2TableOpt.map(DataSourceV2Relation.create).getOrElse(u) } } @@ -770,40 +761,41 @@ class Analyzer( object ResolveInsertInto extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoStatement( - UnresolvedRelation(CatalogObjectIdentifier(Some(tableCatalog), ident)), _, _, _, _) - if i.query.resolved => - loadTable(tableCatalog, ident) - .map(DataSourceV2Relation.create) - .map(relation => { - // ifPartitionNotExists is append with validation, but validation is not supported - if (i.ifPartitionNotExists) { - throw new AnalysisException( - s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}") - } + case i @ InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) if i.query.resolved => + lookupV2Relation(u.multipartIdentifier) match { + case scala.Left((_, _, Some(v2Table: Table))) => + resolveV2Insert(i, v2Table) + case scala.Right(Some(v2Table: Table)) => + resolveV2Insert(i, v2Table) + case _ => + InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists) + } + } - val partCols = partitionColumnNames(relation.table) - validatePartitionSpec(partCols, i.partitionSpec) + private def resolveV2Insert(i: InsertIntoStatement, table: Table): LogicalPlan = { + val relation = DataSourceV2Relation.create(table) + // ifPartitionNotExists is append with validation, but validation is not supported + if (i.ifPartitionNotExists) { + throw new AnalysisException( + s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}") + } - val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) - val query = addStaticPartitionColumns(relation, i.query, staticPartitions) - val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && - conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + val partCols = partitionColumnNames(relation.table) + validatePartitionSpec(partCols, i.partitionSpec) - if (!i.overwrite) { - AppendData.byPosition(relation, query) - } else if (dynamicPartitionOverwrite) { - OverwritePartitionsDynamic.byPosition(relation, query) - } else { - OverwriteByExpression.byPosition( - relation, query, staticDeleteExpression(relation, staticPartitions)) - } - }) - .getOrElse(i) + val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) + val query = addStaticPartitionColumns(relation, i.query, staticPartitions) + val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && + conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC - case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _) - if i.query.resolved => - InsertIntoTable(i.table, 
          i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists)
+      if (!i.overwrite) {
+        AppendData.byPosition(relation, query)
+      } else if (dynamicPartitionOverwrite) {
+        OverwritePartitionsDynamic.byPosition(relation, query)
+      } else {
+        OverwriteByExpression.byPosition(
+          relation, query, staticDeleteExpression(relation, staticPartitions))
+      }
     }
 
     private def partitionColumnNames(table: Table): Seq[String] = {
@@ -2773,6 +2765,39 @@ class Analyzer(
       }
     }
   }
+
+  /**
+   * Performs the lookup of DataSourceV2 Tables. The order of resolution is:
+   * 1. Check if this relation is a temporary table.
+   * 2. Check if it has a catalog identifier. Here we try to load the table. If we find the
+   *    table, we return it on the Left projection of the Either, because the answer of an
+   *    explicit catalog is authoritative.
+   * 3. Try resolving the relation using the V2SessionCatalog, if one is defined. If the
+   *    V2SessionCatalog returns a V1 table definition (UnresolvedTable), we return `None`
+   *    on the right side so that we can fall back to the V1 code paths.
+   * The basic idea is: if a value is returned on the Left, a v2 catalog is defined and must
+   * be used to resolve the table. If a value is returned on the Right, we create a V2
+   * relation when a V2 Table is defined; otherwise we defer to the V1 code paths.
+   */
+  private def lookupV2Relation(
+      identifier: Seq[String]
+  ): Either[(CatalogPlugin, Identifier, Option[Table]), Option[Table]] = {
+    import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._
+
+    identifier match {
+      case AsTemporaryViewIdentifier(ti) if catalog.isTemporaryTable(ti) =>
+        scala.Right(None)
+      case CatalogObjectIdentifier(Some(v2Catalog), ident) =>
+        scala.Left((v2Catalog, ident, loadTable(v2Catalog, ident)))
+      case CatalogObjectIdentifier(None, ident) =>
+        catalogManager.v2SessionCatalog.flatMap(loadTable(_, ident)) match {
+          case Some(_: UnresolvedTable) => scala.Right(None)
+          case other => scala.Right(other)
+        }
+      case _ => scala.Right(None)
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index bd54c66992..920ca3e473 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.SchemaUtils
 
 /**
  * Throws user facing errors when passed invalid queries that fail to analyze.
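
The Either returned by lookupV2Relation above is what ResolveTables and ResolveInsertInto switch on. As a reading aid, the sketch below restates that contract with simplified stand-in types; they are not Spark's CatalogPlugin, Identifier, or Table classes, and the strings only paraphrase what the analyzer rules above do with each case.

// Stand-in types for illustration only; not the real Spark classes.
object LookupV2RelationContract {
  sealed trait Table
  case object SomeV2Table extends Table

  // Left:  an explicit v2 catalog owns the identifier, so its answer is authoritative.
  // Right: session-catalog path; None means "defer to the V1 resolution rules"
  //        (temporary views, or session-catalog tables that came back as UnresolvedTable).
  type LookupResult = Either[(String, Option[Table]), Option[Table]]

  def describe(result: LookupResult): String = result match {
    case Left((catalogName, Some(_))) => s"create DataSourceV2Relation from catalog '$catalogName'"
    case Left((catalogName, None))    => s"no such table in '$catalogName'; leave the relation unresolved"
    case Right(Some(_))               => "create DataSourceV2Relation via the v2 session catalog"
    case Right(None)                  => "fall back to the V1 code paths"
  }

  def main(args: Array[String]): Unit = {
    println(describe(Left(("testcat", Some(SomeV2Table))))) // explicit catalog hit
    println(describe(Right(None)))                          // e.g. a temporary view name
  }
}
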
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 0b49cf24e6..4b22196588 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -367,10 +367,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { ) } - df.sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { + val session = df.sparkSession + val provider = DataSource.lookupDataSource(source, session.sessionState.conf) + val canUseV2 = canUseV2Source(session, provider) + val sessionCatalogOpt = session.sessionState.analyzer.sessionCatalog + + session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case CatalogObjectIdentifier(Some(catalog), ident) => insertInto(catalog, ident) - // TODO(SPARK-28667): Support the V2SessionCatalog + + case CatalogObjectIdentifier(None, ident) + if canUseV2 && sessionCatalogOpt.isDefined && ident.namespace().length <= 1 => + insertInto(sessionCatalogOpt.get, ident) + case AsTableIdentifier(tableIdentifier) => insertInto(tableIdentifier) case other => @@ -382,7 +391,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ - val table = DataSourceV2Relation.create(catalog.asTableCatalog.loadTable(ident)) + val table = catalog.asTableCatalog.loadTable(ident) match { + case _: UnresolvedTable => + return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption)) + case t => + DataSourceV2Relation.create(t) + } val command = modeForDSV2 match { case SaveMode.Append => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala index 22ebfeea04..61a01cb722 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -18,26 +18,43 @@ package org.apache.spark.sql.sources.v2 import java.util -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} -import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, SaveMode} +import org.apache.spark.sql.catalog.v2.CatalogPlugin import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG} import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap class DataSourceV2DataFrameSessionCatalogSuite - extends SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] { + extends InsertIntoTests(supportsDynamicOverwrite = 
true, includeSQLOnlyTests = false) + with SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] { + + import testImplicits._ + + override protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = { + val dfw = insert.write.format(v2Format) + if (mode != null) { + dfw.mode(mode) + } + dfw.insertInto(tableName) + } + + override protected def verifyTable(tableName: String, expected: DataFrame): Unit = { + checkAnswer(spark.table(tableName), expected) + checkAnswer(sql(s"SELECT * FROM $tableName"), expected) + checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected) + checkAnswer(sql(s"TABLE $tableName"), expected) + } + + override protected val catalogAndNamespace: String = "" test("saveAsTable: Append mode should not fail if the table already exists " + "and a same-name temp view exist") { @@ -97,21 +114,16 @@ private[v2] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalog protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName before { - spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, catalogClassName) + spark.conf.set(V2_SESSION_CATALOG.key, catalogClassName) } override def afterEach(): Unit = { super.afterEach() catalog("session").asInstanceOf[Catalog].clearTables() - spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName) + spark.conf.set(V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName) } - protected def verifyTable(tableName: String, expected: DataFrame): Unit = { - checkAnswer(spark.table(tableName), expected) - checkAnswer(sql(s"SELECT * FROM $tableName"), expected) - checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected) - checkAnswer(sql(s"TABLE $tableName"), expected) - } + protected def verifyTable(tableName: String, expected: DataFrame): Unit import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala index af9e56a3b9..d544882f39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala @@ -17,13 +17,10 @@ package org.apache.spark.sql.sources.v2 -import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.{DataFrame, Row, SaveMode} -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode} -import org.apache.spark.sql.test.SharedSparkSession - -class DataSourceV2DataFrameSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { +class DataSourceV2DataFrameSuite + extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = false) { import testImplicits._ before { @@ -31,25 +28,24 @@ class DataSourceV2DataFrameSuite extends QueryTest with SharedSparkSession with spark.conf.set("spark.sql.catalog.testcat2", classOf[TestInMemoryTableCatalog].getName) } - test("insertInto: append") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - df.write.insertInto(t1) - checkAnswer(spark.table(t1), df) - } + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() } - test("insertInto: append by position") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data 
string) USING foo") - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") - dfr.write.insertInto(t1) - checkAnswer(spark.table(t1), df) + override protected val catalogAndNamespace: String = "testcat.ns1.ns2.tbls" + override protected val v2Format: String = classOf[FakeV2Provider].getName + + override def verifyTable(tableName: String, expected: DataFrame): Unit = { + checkAnswer(spark.table(tableName), expected) + } + + override protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = { + val dfw = insert.write.format(v2Format) + if (mode != null) { + dfw.mode(mode) } + dfw.insertInto(tableName) } test("insertInto: append across catalog") { @@ -65,83 +61,6 @@ class DataSourceV2DataFrameSuite extends QueryTest with SharedSparkSession with } } - test("insertInto: append partitioned table") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - df.write.insertInto(t1) - checkAnswer(spark.table(t1), df) - } - } - - test("insertInto: overwrite non-partitioned table") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - val df2 = Seq((4L, "d"), (5L, "e"), (6L, "f")).toDF("id", "data") - df.write.insertInto(t1) - df2.write.mode("overwrite").insertInto(t1) - checkAnswer(spark.table(t1), df2) - } - } - - test("insertInto: overwrite partitioned table in static mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1) - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - df.write.mode("overwrite").insertInto(t1) - checkAnswer(spark.table(t1), df) - } - } - } - - - test("insertInto: overwrite partitioned table in static mode by position") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1) - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") - dfr.write.mode("overwrite").insertInto(t1) - checkAnswer(spark.table(t1), df) - } - } - } - - test("insertInto: overwrite partitioned table in dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1) - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - df.write.mode("overwrite").insertInto(t1) - checkAnswer(spark.table(t1), df.union(sql("SELECT 4L, 'keep'"))) - } - } - } - - test("insertInto: overwrite partitioned table in dynamic mode by position") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data 
string) USING foo PARTITIONED BY (id)") - Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1) - val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") - val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") - dfr.write.mode("overwrite").insertInto(t1) - checkAnswer(spark.table(t1), df.union(sql("SELECT 4L, 'keep'"))) - } - } - } - testQuietly("saveAsTable: table doesn't exist => create table") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSessionCatalogSuite.scala new file mode 100644 index 0000000000..d7f8c373f8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSessionCatalogSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode} + +class DataSourceV2SQLSessionCatalogSuite + extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true) + with SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] { + + import testImplicits._ + + override protected val catalogAndNamespace = "" + + override protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = { + val tmpView = "tmp_view" + withTempView(tmpView) { + insert.createOrReplaceTempView(tmpView) + val overwrite = if (mode == SaveMode.Overwrite) "OVERWRITE" else "INTO" + sql(s"INSERT $overwrite TABLE $tableName SELECT * FROM $tmpView") + } + } + + override protected def verifyTable(tableName: String, expected: DataFrame): Unit = { + checkAnswer(spark.table(tableName), expected) + checkAnswer(sql(s"SELECT * FROM $tableName"), expected) + checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected) + checkAnswer(sql(s"TABLE $tableName"), expected) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index 7cc5679c43..965e7006d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -19,32 +19,45 @@ package org.apache.spark.sql.sources.v2 import scala.collection.JavaConverters._ -import org.scalatest.BeforeAndAfter - import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql._ import 
org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, TableCatalog} import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG} +import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG import org.apache.spark.sql.sources.v2.internal.UnresolvedTable -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, BooleanType, DoubleType, IntegerType, LongType, MapType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap -class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { +class DataSourceV2SQLSuite + extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true) { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ private val orc2 = classOf[OrcDataSourceV2].getName private val v2Source = classOf[FakeV2Provider].getName + override protected val v2Format = v2Source + override protected val catalogAndNamespace = "testcat.ns1.ns2." private def catalog(name: String): CatalogPlugin = { spark.sessionState.catalogManager.catalog(name) } + protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = { + val tmpView = "tmp_view" + withTempView(tmpView) { + insert.createOrReplaceTempView(tmpView) + val overwrite = if (mode == SaveMode.Overwrite) "OVERWRITE" else "INTO" + sql(s"INSERT $overwrite TABLE $tableName SELECT * FROM $tmpView") + } + } + + override def verifyTable(tableName: String, expected: DataFrame): Unit = { + checkAnswer(spark.table(tableName), expected) + } + before { spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) spark.conf.set( @@ -1432,15 +1445,6 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before } } - test("InsertInto: append") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") - sql(s"INSERT INTO $t1 SELECT id, data FROM source") - checkAnswer(spark.table(t1), spark.table("source")) - } - } - test("InsertInto: append - across catalog") { val t1 = "testcat.ns1.ns2.tbl" val t2 = "testcat2.db.tbl" @@ -1452,283 +1456,6 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before } } - test("InsertInto: append to partitioned table - without PARTITION clause") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO TABLE $t1 SELECT * FROM source") - checkAnswer(spark.table(t1), spark.table("source")) - } - } - - test("InsertInto: append to partitioned table - with PARTITION clause") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO TABLE $t1 PARTITION (id) SELECT * FROM source") - checkAnswer(spark.table(t1), spark.table("source")) - } - } - - test("InsertInto: dynamic PARTITION clause fails with non-partition column") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo 
PARTITIONED BY (id)") - - val exc = intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $t1 PARTITION (data) SELECT * FROM source") - } - - assert(spark.table(t1).count === 0) - assert(exc.getMessage.contains("PARTITION clause cannot contain a non-partition column name")) - assert(exc.getMessage.contains("data")) - } - } - - test("InsertInto: static PARTITION clause fails with non-partition column") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (data)") - - val exc = intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $t1 PARTITION (id=1) SELECT data FROM source") - } - - assert(spark.table(t1).count === 0) - assert(exc.getMessage.contains("PARTITION clause cannot contain a non-partition column name")) - assert(exc.getMessage.contains("id")) - } - } - - test("InsertInto: fails when missing a column") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, missing string) USING foo") - val exc = intercept[AnalysisException] { - sql(s"INSERT INTO $t1 SELECT id, data FROM source") - } - - assert(spark.table(t1).count === 0) - assert(exc.getMessage.contains(s"Cannot write to '$t1', not enough data columns")) - } - } - - test("InsertInto: fails when an extra column is present") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") - val exc = intercept[AnalysisException] { - sql(s"INSERT INTO $t1 SELECT id, data, 'fruit' FROM source") - } - - assert(spark.table(t1).count === 0) - assert(exc.getMessage.contains(s"Cannot write to '$t1', too many data columns")) - } - } - - test("InsertInto: append to partitioned table - static clause") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO $t1 PARTITION (id = 23) SELECT data FROM source") - checkAnswer(spark.table(t1), sql("SELECT 23, data FROM source")) - } - } - - test("InsertInto: overwrite non-partitioned table") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 USING foo AS SELECT * FROM source") - sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source2") - checkAnswer(spark.table(t1), spark.table("source2")) - } - } - - test("InsertInto: overwrite - dynamic clause - static mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a"), - Row(2, "b"), - Row(3, "c"))) - } - } - } - - test("InsertInto: overwrite - dynamic clause - dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a"), - Row(2, "b"), - Row(3, "c"), - Row(4, "keep"))) - } - } - } - - test("InsertInto: overwrite - missing clause - static mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { - val t1 = 
"testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") - sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a"), - Row(2, "b"), - Row(3, "c"))) - } - } - } - - test("InsertInto: overwrite - missing clause - dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") - sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a"), - Row(2, "b"), - Row(3, "c"), - Row(4, "keep"))) - } - } - } - - test("InsertInto: overwrite - static clause") { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p1 int) USING foo PARTITIONED BY (p1)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 23), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p1 = 23) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a", 23), - Row(2, "b", 23), - Row(3, "c", 23), - Row(4, "keep", 2))) - } - } - - test("InsertInto: overwrite - mixed clause - static mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a", 2), - Row(2, "b", 2), - Row(3, "c", 2))) - } - } - } - - test("InsertInto: overwrite - mixed clause reordered - static mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a", 2), - Row(2, "b", 2), - Row(3, "c", 2))) - } - } - } - - test("InsertInto: overwrite - implicit dynamic partition - static mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a", 2), - Row(2, "b", 2), - Row(3, "c", 2))) - } - } - } - - test("InsertInto: overwrite - mixed clause - dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a", 2), - Row(2, "b", 2), - Row(3, "c", 2), - Row(4, 
"keep", 2))) - } - } - } - - test("InsertInto: overwrite - mixed clause reordered - dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a", 2), - Row(2, "b", 2), - Row(3, "c", 2), - Row(4, "keep", 2))) - } - } - } - - test("InsertInto: overwrite - implicit dynamic partition - dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM source") - checkAnswer(spark.table(t1), Seq( - Row(1, "a", 2), - Row(2, "b", 2), - Row(3, "c", 2), - Row(4, "keep", 2))) - } - } - } - - test("InsertInto: overwrite - multiple static partitions - dynamic mode") { - withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { - val t1 = "testcat.ns1.ns2.tbl" - withTable(t1) { - sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM source") - checkAnswer(spark.table(t1), Seq( - Row(2, "a", 2), - Row(2, "b", 2), - Row(2, "c", 2), - Row(4, "keep", 2))) - } - } - } - test("ShowTables: using v2 catalog") { spark.sql("CREATE TABLE testcat.db.table_name (id bigint, data string) USING foo") spark.sql("CREATE TABLE testcat.n1.n2.db.table_name (id bigint, data string) USING foo") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/InsertIntoTests.scala new file mode 100644 index 0000000000..8cc1b320cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/InsertIntoTests.scala @@ -0,0 +1,467 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode} +import org.apache.spark.sql.test.SharedSparkSession + +/** + * A collection of "INSERT INTO" tests that can be run through the SQL or DataFrameWriter APIs. 
+ * Extending test suites can implement the `doInsert` method to run the insert through either + * API. + * + * @param supportsDynamicOverwrite Whether the Table implementations used in the test suite support + * dynamic partition overwrites. If they do, we will check for the + * success of the operations. If not, then we will check that we + * failed with the right error message. + * @param includeSQLOnlyTests Certain INSERT INTO behavior can be achieved purely through SQL, e.g. + * static or dynamic partition overwrites. This flag should be set to + * true if we would like to test these cases. + */ +abstract class InsertIntoTests( + override protected val supportsDynamicOverwrite: Boolean, + override protected val includeSQLOnlyTests: Boolean) extends InsertIntoSQLOnlyTests { + + import testImplicits._ + + /** + * Insert data into a table using the insertInto statement. Implementations can be in SQL + * ("INSERT") or using the DataFrameWriter (`df.write.insertInto`). + */ + protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode = null): Unit + + test("insertInto: append") { + val t1 = s"${catalogAndNamespace}tbl" + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + doInsert(t1, df) + verifyTable(t1, df) + } + + test("insertInto: append by position") { + val t1 = s"${catalogAndNamespace}tbl" + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") + + doInsert(t1, dfr) + verifyTable(t1, df) + } + + test("insertInto: append partitioned table") { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + doInsert(t1, df) + verifyTable(t1, df) + } + } + + test("insertInto: overwrite non-partitioned table") { + val t1 = s"${catalogAndNamespace}tbl" + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + val df2 = Seq((4L, "d"), (5L, "e"), (6L, "f")).toDF("id", "data") + doInsert(t1, df) + doInsert(t1, df2, SaveMode.Overwrite) + verifyTable(t1, df2) + } + + test("insertInto: overwrite partitioned table in static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + val init = Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data") + doInsert(t1, init) + + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + doInsert(t1, df, SaveMode.Overwrite) + verifyTable(t1, df) + } + } + + + test("insertInto: overwrite partitioned table in static mode by position") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + val init = Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data") + doInsert(t1, init) + + val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") + doInsert(t1, dfr, SaveMode.Overwrite) + + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + verifyTable(t1, df) + } + } + } + + test("insertInto: fails when missing a column") { + val 
t1 = s"${catalogAndNamespace}tbl" + sql(s"CREATE TABLE $t1 (id bigint, data string, missing string) USING $v2Format") + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + val exc = intercept[AnalysisException] { + doInsert(t1, df) + } + + verifyTable(t1, Seq.empty[(Long, String, String)].toDF("id", "data", "missing")) + val tableName = if (catalogAndNamespace.isEmpty) s"default.$t1" else t1 + assert(exc.getMessage.contains(s"Cannot write to '$tableName', not enough data columns")) + } + + test("insertInto: fails when an extra column is present") { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") + val df = Seq((1L, "a", "mango")).toDF("id", "data", "fruit") + val exc = intercept[AnalysisException] { + doInsert(t1, df) + } + + verifyTable(t1, Seq.empty[(Long, String)].toDF("id", "data")) + val tableName = if (catalogAndNamespace.isEmpty) s"default.$t1" else t1 + assert(exc.getMessage.contains(s"Cannot write to '$tableName', too many data columns")) + } + } + + dynamicOverwriteTest("insertInto: overwrite partitioned table in dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + val init = Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data") + doInsert(t1, init) + + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + doInsert(t1, df, SaveMode.Overwrite) + + verifyTable(t1, df.union(sql("SELECT 4L, 'keep'"))) + } + } + + dynamicOverwriteTest("insertInto: overwrite partitioned table in dynamic mode by position") { + val t1 = s"${catalogAndNamespace}tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + val init = Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data") + doInsert(t1, init) + + val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") + doInsert(t1, dfr, SaveMode.Overwrite) + + val df = Seq((1L, "a"), (2L, "b"), (3L, "c"), (4L, "keep")).toDF("id", "data") + verifyTable(t1, df) + } + } +} + +private[v2] trait InsertIntoSQLOnlyTests + extends QueryTest + with SharedSparkSession + with BeforeAndAfter { + + import testImplicits._ + + /** Check that the results in `tableName` match the `expected` DataFrame. */ + protected def verifyTable(tableName: String, expected: DataFrame): Unit + + protected val v2Format: String + protected val catalogAndNamespace: String + + /** + * Whether dynamic partition overwrites are supported by the `Table` definitions used in the + * test suites. Tables that leverage the V1 Write interface do not support dynamic partition + * overwrites. + */ + protected val supportsDynamicOverwrite: Boolean + + /** Whether to include the SQL specific tests in this trait within the extending test suite. 
*/ + protected val includeSQLOnlyTests: Boolean + + private def withTableAndData(tableName: String)(testFn: String => Unit): Unit = { + withTable(tableName) { + val viewName = "tmp_view" + val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView(viewName) + withTempView(viewName) { + testFn(viewName) + } + } + } + + protected def dynamicOverwriteTest(testName: String)(f: => Unit): Unit = { + test(testName) { + try { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + f + } + if (!supportsDynamicOverwrite) { + fail("Expected failure from test, because the table doesn't support dynamic overwrites") + } + } catch { + case a: AnalysisException if !supportsDynamicOverwrite => + assert(a.getMessage.contains("Table does not support dynamic overwrite")) + } + } + } + + if (includeSQLOnlyTests) { + test("InsertInto: when the table doesn't exist") { + val t1 = s"${catalogAndNamespace}tbl" + val t2 = s"${catalogAndNamespace}tbl2" + withTableAndData(t1) { _ => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") + val e = intercept[AnalysisException] { + sql(s"INSERT INTO $t2 VALUES (2L, 'dummy')") + } + assert(e.getMessage.contains(t2)) + assert(e.getMessage.contains("Table not found")) + } + } + + test("InsertInto: append to partitioned table - static clause") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 PARTITION (id = 23) SELECT data FROM $view") + verifyTable(t1, sql(s"SELECT 23, data FROM $view")) + } + } + + test("InsertInto: static PARTITION clause fails with non-partition column") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (data)") + + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $t1 PARTITION (id=1) SELECT data FROM $view") + } + + verifyTable(t1, spark.emptyDataFrame) + assert(exc.getMessage.contains( + "PARTITION clause cannot contain a non-partition column name")) + assert(exc.getMessage.contains("id")) + } + } + + test("InsertInto: dynamic PARTITION clause fails with non-partition column") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $t1 PARTITION (data) SELECT * FROM $view") + } + + verifyTable(t1, spark.emptyDataFrame) + assert(exc.getMessage.contains( + "PARTITION clause cannot contain a non-partition column name")) + assert(exc.getMessage.contains("data")) + } + } + + test("InsertInto: overwrite - dynamic clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a"), + (2, "b"), + (3, "c")).toDF()) + } + } + } + + dynamicOverwriteTest("InsertInto: overwrite - dynamic clause - dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format 
PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a"), + (2, "b"), + (3, "c"), + (4, "keep")).toDF("id", "data")) + } + } + + test("InsertInto: overwrite - missing clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a"), + (2, "b"), + (3, "c")).toDF("id", "data")) + } + } + } + + dynamicOverwriteTest("InsertInto: overwrite - missing clause - dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a"), + (2, "b"), + (3, "c"), + (4, "keep")).toDF("id", "data")) + } + } + + test("InsertInto: overwrite - static clause") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p1 int) " + + s"USING $v2Format PARTITIONED BY (p1)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 23), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p1 = 23) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a", 23), + (2, "b", 23), + (3, "c", 23), + (4, "keep", 2)).toDF("id", "data", "p1")) + } + } + + test("InsertInto: overwrite - mixed clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a", 2), + (2, "b", 2), + (3, "c", 2)).toDF("id", "data", "p")) + } + } + } + + test("InsertInto: overwrite - mixed clause reordered - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a", 2), + (2, "b", 2), + (3, "c", 2)).toDF("id", "data", "p")) + } + } + } + + test("InsertInto: overwrite - implicit dynamic partition - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a", 2), + (2, "b", 2), + (3, "c", 2)).toDF("id", "data", "p")) + } + } + } + + 
dynamicOverwriteTest("InsertInto: overwrite - mixed clause - dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a", 2), + (2, "b", 2), + (3, "c", 2), + (4, "keep", 2)).toDF("id", "data", "p")) + } + } + + dynamicOverwriteTest("InsertInto: overwrite - mixed clause reordered - dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a", 2), + (2, "b", 2), + (3, "c", 2), + (4, "keep", 2)).toDF("id", "data", "p")) + } + } + + dynamicOverwriteTest("InsertInto: overwrite - implicit dynamic partition - dynamic mode") { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM $view") + verifyTable(t1, Seq( + (1, "a", 2), + (2, "b", 2), + (3, "c", 2), + (4, "keep", 2)).toDF("id", "data", "p")) + } + } + + test("InsertInto: overwrite - multiple static partitions - dynamic mode") { + // Since all partitions are provided statically, this should be supported by everyone + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = s"${catalogAndNamespace}tbl" + withTableAndData(t1) { view => + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + + s"USING $v2Format PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view") + verifyTable(t1, Seq( + (2, "a", 2), + (2, "b", 2), + (2, "c", 2), + (4, "keep", 2)).toDF("id", "data", "p")) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V1WriteFallbackSuite.scala index 60e2443d09..846eba2806 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V1WriteFallbackSuite.scala @@ -24,8 +24,9 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode} import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation} import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase import org.apache.spark.sql.sources.v2.writer.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} @@ -68,13 +69,25 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before } class 
V1WriteFallbackSessionCatalogSuite - extends SessionCatalogTest[InMemoryTableWithV1Fallback, V1FallbackTableCatalog] { + extends InsertIntoTests(supportsDynamicOverwrite = false, includeSQLOnlyTests = true) + with SessionCatalogTest[InMemoryTableWithV1Fallback, V1FallbackTableCatalog] { + override protected val v2Format = classOf[InMemoryV1Provider].getName override protected val catalogClassName: String = classOf[V1FallbackTableCatalog].getName + override protected val catalogAndNamespace: String = "" override protected def verifyTable(tableName: String, expected: DataFrame): Unit = { checkAnswer(InMemoryV1Provider.getTableData(spark, s"default.$tableName"), expected) } + + protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = { + val tmpView = "tmp_view" + withTempView(tmpView) { + insert.createOrReplaceTempView(tmpView) + val overwrite = if (mode == SaveMode.Overwrite) "OVERWRITE" else "INTO" + sql(s"INSERT $overwrite TABLE $tableName SELECT * FROM $tmpView") + } + } } class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV1Fallback] {
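
Taken together, the analyzer and DataFrameWriter changes in this patch let an unqualified table name be served by a user-supplied v2 session catalog on both the DataFrame and SQL insert paths. The snippet below is a rough usage sketch, not part of the patch: com.example.MySessionCatalog and com.example.MyV2Provider are hypothetical stand-ins for a TableCatalog and a v2 data source (the test suites above use InMemoryTableSessionCatalog and FakeV2Provider for the same purpose), and whether the DataFrame path takes the v2 route still depends on the source passing the canUseV2Source check added to DataFrameWriter.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Route unqualified (session-catalog) table names through a v2 TableCatalog implementation.
spark.conf.set(V2_SESSION_CATALOG.key, "com.example.MySessionCatalog")

spark.sql("CREATE TABLE demo_tbl (id bigint, data string) USING com.example.MyV2Provider " +
  "PARTITIONED BY (id)")

// DataFrame path: with a v2 provider and the session catalog configured, this should be
// planned as AppendData on a DataSourceV2Relation instead of the v1 InsertIntoTable.
Seq((1L, "a"), (2L, "b")).toDF("id", "data")
  .write.format("com.example.MyV2Provider").insertInto("demo_tbl")

// SQL path: resolved by ResolveInsertInto through lookupV2Relation.
spark.sql("INSERT OVERWRITE TABLE demo_tbl SELECT * FROM VALUES (3L, 'c') AS t(id, data)")
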