[SPARK-28572][SQL] Simple analyzer checks for v2 table creation code paths

## What changes were proposed in this pull request?

Adds the following analyzer checks to the V2 table creation code paths (see the examples below):
 - columns referenced by partition transforms must exist in the table schema, including nested fields
 - partition transforms must not be duplicated
 - column names are resolved with the configured case sensitivity
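
As an illustration, the statements below are now rejected during analysis. This is a hedged sketch modeled on the new tests in DataSourceV2SQLSuite: `testcat` is the catalog the suite registers and `v2Source` is its FakeV2Provider class name, not defaults.

```scala
// Each statement now fails with an AnalysisException (schematic, mirroring the new tests).
sql(s"CREATE TABLE testcat.t (a INT, a INT) USING $v2Source")                 // duplicate column
sql(s"CREATE TABLE testcat.t (d STRUCT<a: INT, a: INT>) USING $v2Source")     // duplicate nested column
sql(s"CREATE TABLE testcat.t (a INT) USING $v2Source PARTITIONED BY (a, a)")  // duplicate partition column
sql(s"CREATE TABLE testcat.t (a INT, b STRING) USING $v2Source " +
  "CLUSTERED BY (c) INTO 4 BUCKETS")                                          // bucket column not in schema
```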

## How was this patch tested?

New unit tests in DataSourceV2SQLSuite.

Closes #25305 from brkyvz/v2CreateTable.

Authored-by: Burak Yavuz <brkyvz@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Committed by Wenchen Fan on 2019-08-09 12:04:28 +08:00
commit c80430f5c9 (parent 2580c1bfe2)
8 changed files with 452 additions and 25 deletions

@ -59,10 +59,18 @@ private[sql] object LogicalExpressions {
def hours(column: String): HoursTransform = HoursTransform(reference(column))
}
/**
* Allows Spark to rewrite the given references of the transform during analysis.
*/
sealed trait RewritableTransform extends Transform {
/** Creates a copy of this transform with the new analyzed references. */
def withReferences(newReferences: Seq[NamedReference]): Transform
}
/**
* Base class for simple transforms of a single column.
*/
private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends Transform {
private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends RewritableTransform {
def reference: NamedReference = ref
@ -73,18 +81,24 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends T
override def describe: String = name + "(" + reference.describe + ")"
override def toString: String = describe
protected def withNewRef(ref: NamedReference): Transform
override def withReferences(newReferences: Seq[NamedReference]): Transform = {
assert(newReferences.length == 1,
s"Tried rewriting a single column transform (${this}) with multiple references.")
withNewRef(newReferences.head)
}
}
private[sql] final case class BucketTransform(
numBuckets: Literal[Int],
columns: Seq[NamedReference]) extends Transform {
columns: Seq[NamedReference]) extends RewritableTransform {
override val name: String = "bucket"
override def references: Array[NamedReference] = {
arguments
.filter(_.isInstanceOf[NamedReference])
.map(_.asInstanceOf[NamedReference])
arguments.collect { case named: NamedReference => named }
}
override def arguments: Array[Expression] = numBuckets +: columns.toArray
@ -92,6 +106,10 @@ private[sql] final case class BucketTransform(
override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"
override def toString: String = describe
override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences)
}
}
private[sql] object BucketTransform {
@ -112,9 +130,7 @@ private[sql] final case class ApplyTransform(
override def arguments: Array[Expression] = args.toArray
override def references: Array[NamedReference] = {
arguments
.filter(_.isInstanceOf[NamedReference])
.map(_.asInstanceOf[NamedReference])
arguments.collect { case named: NamedReference => named }
}
override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
@ -143,7 +159,7 @@ private object Ref {
/**
* Convenience extractor for any Transform.
*/
private object NamedTransform {
private[sql] object NamedTransform {
def unapply(transform: Transform): Some[(String, Seq[Expression])] = {
Some((transform.name, transform.arguments))
}
@ -153,6 +169,7 @@ private[sql] final case class IdentityTransform(
ref: NamedReference) extends SingleColumnTransform(ref) {
override val name: String = "identity"
override def describe: String = ref.describe
override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
}
private[sql] object IdentityTransform {
@ -167,6 +184,7 @@ private[sql] object IdentityTransform {
private[sql] final case class YearsTransform(
ref: NamedReference) extends SingleColumnTransform(ref) {
override val name: String = "years"
override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
}
private[sql] object YearsTransform {
@ -181,6 +199,7 @@ private[sql] object YearsTransform {
private[sql] final case class MonthsTransform(
ref: NamedReference) extends SingleColumnTransform(ref) {
override val name: String = "months"
override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
}
private[sql] object MonthsTransform {
@ -195,6 +214,7 @@ private[sql] object MonthsTransform {
private[sql] final case class DaysTransform(
ref: NamedReference) extends SingleColumnTransform(ref) {
override val name: String = "days"
override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
}
private[sql] object DaysTransform {
@ -209,6 +229,7 @@ private[sql] object DaysTransform {
private[sql] final case class HoursTransform(
ref: NamedReference) extends SingleColumnTransform(ref) {
override val name: String = "hours"
override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
}
private[sql] object HoursTransform {
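
To make the new trait concrete, here is a minimal sketch of rewriting a single-column transform's reference via `withReferences`. The column names are made up, and these classes are `private[sql]`, so this is illustrative rather than public API:

```scala
import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, LogicalExpressions, RewritableTransform}

// HoursTransform is a SingleColumnTransform and therefore a RewritableTransform.
val transform = LogicalExpressions.hours("eventTime")      // reference as written by the user
val rewritten = (transform: RewritableTransform)
  .withReferences(Seq(FieldReference(Seq("event_time"))))  // analyzed, normalized reference
// rewritten.references().head.fieldNames() => Array("event_time")
```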

@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
/**
* Throws user facing errors when passed invalid queries that fail to analyze.
@ -299,10 +300,10 @@ trait CheckAnalysis extends PredicateHelper {
}
}
case CreateTableAsSelect(_, _, partitioning, query, _, _, _) =>
val references = partitioning.flatMap(_.references).toSet
case create: V2CreateTablePlan =>
val references = create.partitioning.flatMap(_.references).toSet
val badReferences = references.map(_.fieldNames).flatMap { column =>
query.schema.findNestedField(column) match {
create.tableSchema.findNestedField(column) match {
case Some(_) =>
None
case _ =>


@ -418,7 +418,11 @@ case class CreateV2Table(
tableSchema: StructType,
partitioning: Seq[Transform],
properties: Map[String, String],
ignoreIfExists: Boolean) extends Command
ignoreIfExists: Boolean) extends Command with V2CreateTablePlan {
override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
this.copy(partitioning = rewritten)
}
}
/**
* Create a new table from a select query with a v2 catalog.
@ -430,8 +434,9 @@ case class CreateTableAsSelect(
query: LogicalPlan,
properties: Map[String, String],
writeOptions: Map[String, String],
ignoreIfExists: Boolean) extends Command {
ignoreIfExists: Boolean) extends Command with V2CreateTablePlan {
override def tableSchema: StructType = query.schema
override def children: Seq[LogicalPlan] = Seq(query)
override lazy val resolved: Boolean = childrenResolved && {
@ -440,6 +445,10 @@ case class CreateTableAsSelect(
val references = partitioning.flatMap(_.references).toSet
references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined)
}
override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
this.copy(partitioning = rewritten)
}
}
/**
@ -456,7 +465,11 @@ case class ReplaceTable(
tableSchema: StructType,
partitioning: Seq[Transform],
properties: Map[String, String],
orCreate: Boolean) extends Command
orCreate: Boolean) extends Command with V2CreateTablePlan {
override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
this.copy(partitioning = rewritten)
}
}
/**
* Replaces a table from a select query with a v2 catalog.
@ -471,8 +484,9 @@ case class ReplaceTableAsSelect(
query: LogicalPlan,
properties: Map[String, String],
writeOptions: Map[String, String],
orCreate: Boolean) extends Command {
orCreate: Boolean) extends Command with V2CreateTablePlan {
override def tableSchema: StructType = query.schema
override def children: Seq[LogicalPlan] = Seq(query)
override lazy val resolved: Boolean = {
@ -481,6 +495,10 @@ case class ReplaceTableAsSelect(
val references = partitioning.flatMap(_.references).toSet
references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined)
}
override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
this.copy(partitioning = rewritten)
}
}
/**
@ -1201,3 +1219,16 @@ case class Deduplicate(
override def output: Seq[Attribute] = child.output
}
/** A trait used for logical plan nodes that create or replace V2 table definitions. */
trait V2CreateTablePlan extends LogicalPlan {
def tableName: Identifier
def partitioning: Seq[Transform]
def tableSchema: StructType
/**
* Creates a copy of this node with the new partitioning transforms. This method is used to
* rewrite the partition transforms so that they are normalized according to the table schema.
*/
def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan
}

@ -17,9 +17,12 @@
package org.apache.spark.sql.util
import java.util.Locale
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalog.v2.expressions._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
/**
@ -88,4 +91,154 @@ private[spark] object SchemaUtils {
s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}")
}
}
/**
* Returns all column names in this schema as a flat list. For example, a schema like:
* | - a
* | | - 1
* | | - 2
* | - b
* | - c
* | | - nest
* | | - 3
* will get flattened to: "a", "a.1", "a.2", "b", "c", "c.nest", "c.nest.3"
*/
def explodeNestedFieldNames(schema: StructType): Seq[String] = {
def explode(schema: StructType): Seq[Seq[String]] = {
def recurseIntoComplexTypes(complexType: DataType): Seq[Seq[String]] = {
complexType match {
case s: StructType => explode(s)
case a: ArrayType => recurseIntoComplexTypes(a.elementType)
case m: MapType =>
recurseIntoComplexTypes(m.keyType).map(Seq("key") ++ _) ++
recurseIntoComplexTypes(m.valueType).map(Seq("value") ++ _)
case _ => Nil
}
}
schema.flatMap {
case StructField(name, s: StructType, _, _) =>
Seq(Seq(name)) ++ explode(s).map(nested => Seq(name) ++ nested)
case StructField(name, a: ArrayType, _, _) =>
Seq(Seq(name)) ++ recurseIntoComplexTypes(a).map(nested => Seq(name) ++ nested)
case StructField(name, m: MapType, _, _) =>
Seq(Seq(name)) ++ recurseIntoComplexTypes(m).map(nested => Seq(name) ++ nested)
case f => Seq(f.name) :: Nil
}
}
explode(schema).map(UnresolvedAttribute.apply(_).name)
}
/**
* Checks whether the partitioning transforms contain duplicates. Throws an exception if
* duplicates exist.
*
* @param transforms the transforms to check for duplicates
* @param checkType contextual information around the check, used in an exception message
* @param isCaseSensitive Whether to be case sensitive when comparing column names
*/
def checkTransformDuplication(
transforms: Seq[Transform],
checkType: String,
isCaseSensitive: Boolean): Unit = {
val extractedTransforms = transforms.map {
case b: BucketTransform =>
val colNames = b.columns.map(c => UnresolvedAttribute(c.fieldNames()).name)
// We need to check that we're not duplicating columns within our bucketing transform
checkColumnNameDuplication(colNames, "in the bucket definition", isCaseSensitive)
b.name -> colNames
case NamedTransform(transformName, refs) =>
val fieldNameParts =
refs.collect { case FieldReference(parts) => UnresolvedAttribute(parts).name }
// We could also check for duplicate column names here when fieldNameParts.length > 1,
// but we intentionally don't, because certain transforms can legitimately reference
// the same column more than once.
transformName -> fieldNameParts
}
val normalizedTransforms = if (isCaseSensitive) {
extractedTransforms
} else {
extractedTransforms.map(t => t._1 -> t._2.map(_.toLowerCase(Locale.ROOT)))
}
if (normalizedTransforms.distinct.length != normalizedTransforms.length) {
val duplicateColumns = normalizedTransforms.groupBy(identity).collect {
case (x, ys) if ys.length > 1 => s"${x._2.mkString(".")}"
}
throw new AnalysisException(
s"Found duplicate column(s) $checkType: ${duplicateColumns.mkString(", ")}")
}
}
/**
* Returns the given column's ordinal within the given `schema`. The returned position has one
* index per nesting level, so its length matches how deeply the column is nested.
*
* @param column The column to search for in the given struct. If the length of `column` is
* greater than 1, we expect to enter a nested field.
* @param schema The current struct we are looking at.
* @param resolver The resolver to find the column.
*/
def findColumnPosition(
column: Seq[String],
schema: StructType,
resolver: Resolver): Seq[Int] = {
def find(column: Seq[String], schema: StructType, stack: Seq[String]): Seq[Int] = {
if (column.isEmpty) return Nil
val thisCol = column.head
lazy val columnPath = UnresolvedAttribute(stack :+ thisCol).name
val pos = schema.indexWhere(f => resolver(f.name, thisCol))
if (pos == -1) {
throw new IndexOutOfBoundsException(columnPath)
}
val children = schema(pos).dataType match {
case s: StructType =>
find(column.tail, s, stack :+ thisCol)
case ArrayType(s: StructType, _) =>
find(column.tail, s, stack :+ thisCol)
case o =>
if (column.length > 1) {
throw new AnalysisException(
s"""Expected $columnPath to be a nested data type, but found $o. Was looking for the
|index of ${UnresolvedAttribute(column).name} in a nested field
""".stripMargin)
}
Nil
}
Seq(pos) ++ children
}
try {
find(column, schema, Nil)
} catch {
case i: IndexOutOfBoundsException =>
throw new AnalysisException(
s"Couldn't find column ${i.getMessage} in:\n${schema.treeString}")
case e: AnalysisException =>
throw new AnalysisException(e.getMessage + s":\n${schema.treeString}")
}
}
/**
* Gets the name of the column at the given position.
*/
def getColumnName(position: Seq[Int], schema: StructType): Seq[String] = {
val topLevel = schema(position.head)
val field = position.tail.foldLeft(Seq(topLevel.name) -> topLevel) {
case (nameAndField, pos) =>
nameAndField._2.dataType match {
case s: StructType =>
val nowField = s(pos)
(nameAndField._1 :+ nowField.name) -> nowField
case ArrayType(s: StructType, _) =>
val nowField = s(pos)
(nameAndField._1 :+ nowField.name) -> nowField
case _ =>
throw new AnalysisException(
s"The positions provided ($pos) cannot be resolved in\n${schema.treeString}.")
}
}
field._1
}
}
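
A hedged sketch of what the new helpers return for a small, invented nested schema, using the analysis package's case-insensitive resolver (SchemaUtils is `private[spark]`, so this is illustrative only):

```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

val schema = new StructType()
  .add("point", new StructType().add("x", IntegerType).add("y", IntegerType))
  .add("ts", TimestampType)

SchemaUtils.explodeNestedFieldNames(schema)
// Seq("point", "point.x", "point.y", "ts")

SchemaUtils.findColumnPosition(Seq("Point", "Y"), schema, caseInsensitiveResolution)
// Seq(0, 1) -- the resolver tolerates the case mismatch

SchemaUtils.getColumnName(Seq(0, 1), schema)
// Seq("point", "y") -- recovers the schema's own casing
```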

@ -25,7 +25,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog}
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation}
import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, ReplaceTable, ReplaceTableAsSelect}
import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DescribeColumnStatement, DescribeTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement}
@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSour
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.TableProvider
import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils
case class DataSourceResolution(
conf: SQLConf,
@ -123,7 +124,7 @@ case class DataSourceResolution(
case replace: ReplaceTableStatement =>
// the provider was not a v1 source, convert to a v2 plan
val CatalogObjectIdentifier(maybeCatalog, identifier) = replace.tableName
val catalog = maybeCatalog.orElse(defaultCatalog)
val catalog = maybeCatalog.orElse(sessionCatalog)
.getOrElse(throw new AnalysisException(
s"No catalog specified for table ${identifier.quoted} and no default catalog is set"))
.asTableCatalog
@ -132,7 +133,7 @@ case class DataSourceResolution(
case rtas: ReplaceTableAsSelectStatement =>
// the provider was not a v1 source, convert to a v2 plan
val CatalogObjectIdentifier(maybeCatalog, identifier) = rtas.tableName
val catalog = maybeCatalog.orElse(defaultCatalog)
val catalog = maybeCatalog.orElse(sessionCatalog)
.getOrElse(throw new AnalysisException(
s"No catalog specified for table ${identifier.quoted} and no default catalog is set"))
.asTableCatalog

@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import java.util.Locale
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, RewritableTransform}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering}
@ -29,7 +30,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.{AtomicType, StructType}
import org.apache.spark.sql.types.{ArrayType, AtomicType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils
/**
@ -236,6 +237,46 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
c.copy(tableDesc = normalizedTable.copy(schema = reorderedSchema))
}
case create: V2CreateTablePlan =>
val schema = create.tableSchema
val partitioning = create.partitioning
val identifier = create.tableName
val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
// Check that columns are not duplicated in the schema
val flattenedSchema = SchemaUtils.explodeNestedFieldNames(schema)
SchemaUtils.checkColumnNameDuplication(
flattenedSchema,
s"in the table definition of $identifier",
isCaseSensitive)
// Check that columns are not duplicated in the partitioning statement
SchemaUtils.checkTransformDuplication(
partitioning, "in the partitioning", isCaseSensitive)
if (schema.isEmpty) {
if (partitioning.nonEmpty) {
throw new AnalysisException("It is not allowed to specify partitioning when the " +
"table schema is not defined.")
}
create
} else {
// Resolve and normalize partition columns as necessary
val resolver = sparkSession.sessionState.conf.resolver
val normalizedPartitions = partitioning.map {
case transform: RewritableTransform =>
val rewritten = transform.references().map { ref =>
// Throws an exception if the reference cannot be resolved
val position = SchemaUtils.findColumnPosition(ref.fieldNames(), schema, resolver)
FieldReference(SchemaUtils.getColumnName(position, schema))
}
transform.withReferences(rewritten)
case other => other
}
create.withPartitioning(normalizedPartitions)
}
}
private def fallBackV2ToV1(cls: Class[_]): Class[_] = cls.newInstance match {
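
The user-visible effect of the normalization branch above, mirroring the new "partition column case insensitive resolution" test: a partition column typed with the wrong case is rewritten to the schema's casing before the table is created (`testcat` and `v2Source` are the test suite's, not defaults).

```scala
// Schematic, following the new test in DataSourceV2SQLSuite.
sql(s"CREATE TABLE testcat.tbl (a INT, b STRING) USING $v2Source PARTITIONED BY (A)")
// The stored partition transform references "a" (the schema's casing), not "A".
```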

@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.v2
import java.util
import java.util.Locale
import scala.collection.JavaConverters._
import scala.collection.mutable
@ -90,10 +89,11 @@ class V2SessionCatalog(sessionState: SessionState) extends TableCatalog {
val location = Option(properties.get("location"))
val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap)
.copy(locationUri = location.map(CatalogUtils.stringToURI))
val tableType = if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED
val tableDesc = CatalogTable(
identifier = ident.asTableIdentifier,
tableType = CatalogTableType.MANAGED,
tableType = tableType,
storage = storage,
schema = schema,
provider = Some(provider),
@ -252,4 +252,3 @@ private[sql] object V2SessionCatalog {
(identityCols, bucketSpec)
}
}

@ -23,19 +23,22 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalog.v2.Identifier
import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog}
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType, LongType, MapType, Metadata, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
private val orc2 = classOf[OrcDataSourceV2].getName
private val v2Source = classOf[FakeV2Provider].getName
before {
spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
@ -1696,4 +1699,181 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn
}
}
}
test("tableCreation: partition column case insensitive resolution") {
val testCatalog = spark.catalog("testcat").asTableCatalog
val sessionCatalog = spark.catalog("session").asTableCatalog
def checkPartitioning(cat: TableCatalog, partition: String): Unit = {
val table = cat.loadTable(Identifier.of(Array.empty, "tbl"))
val partitions = table.partitioning().map(_.references())
assert(partitions.length === 1)
val fieldNames = partitions.flatMap(_.map(_.fieldNames()))
assert(fieldNames === Array(Array(partition)))
}
sql(s"CREATE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
checkPartitioning(sessionCatalog, "a")
sql(s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
checkPartitioning(testCatalog, "a")
sql(s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
checkPartitioning(sessionCatalog, "b")
sql(s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
checkPartitioning(testCatalog, "b")
}
test("tableCreation: partition column case sensitive resolution") {
def checkFailure(statement: String): Unit = {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val e = intercept[AnalysisException] {
sql(statement)
}
assert(e.getMessage.contains("Couldn't find column"))
}
}
checkFailure(s"CREATE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
checkFailure(s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
checkFailure(
s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
checkFailure(
s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
}
test("tableCreation: duplicate column names in the table definition") {
val errorMsg = "Found duplicate column(s) in the table definition of `t`"
Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
testCreateAnalysisError(
s"CREATE TABLE t ($c0 INT, $c1 INT) USING $v2Source",
errorMsg
)
testCreateAnalysisError(
s"CREATE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE t ($c0 INT, $c1 INT) USING $v2Source",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source",
errorMsg
)
}
}
}
test("tableCreation: duplicate nested column names in the table definition") {
val errorMsg = "Found duplicate column(s) in the table definition of `t`"
Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
testCreateAnalysisError(
s"CREATE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
errorMsg
)
testCreateAnalysisError(
s"CREATE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
errorMsg
)
}
}
}
test("tableCreation: bucket column names not in table definition") {
val errorMsg = "Couldn't find column c in"
testCreateAnalysisError(
s"CREATE TABLE tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS",
errorMsg
)
testCreateAnalysisError(
s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source " +
"CLUSTERED BY (c) INTO 4 BUCKETS",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source " +
"CLUSTERED BY (c) INTO 4 BUCKETS",
errorMsg
)
}
test("tableCreation: column repeated in partition columns") {
val errorMsg = "Found duplicate column(s) in the partitioning"
Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
testCreateAnalysisError(
s"CREATE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
errorMsg
)
testCreateAnalysisError(
s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
errorMsg
)
}
}
}
test("tableCreation: column repeated in bucket columns") {
val errorMsg = "Found duplicate column(s) in the bucket definition"
Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
testCreateAnalysisError(
s"CREATE TABLE t ($c0 INT) USING $v2Source " +
s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
errorMsg
)
testCreateAnalysisError(
s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source " +
s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source " +
s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
errorMsg
)
testCreateAnalysisError(
s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source " +
s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
errorMsg
)
}
}
}
private def testCreateAnalysisError(sqlStatement: String, expectedError: String): Unit = {
val errMsg = intercept[AnalysisException] {
sql(sqlStatement)
}.getMessage
assert(errMsg.contains(expectedError))
}
}
/** Used as a V2 DataSource for V2SessionCatalog DDL */
class FakeV2Provider extends TableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new UnsupportedOperationException("Unnecessary for DDL tests")
}
}