From c80430f5c9189b37ac1209db0453dbd9bb5c767e Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Fri, 9 Aug 2019 12:04:28 +0800
Subject: [PATCH] [SPARK-28572][SQL] Simple analyzer checks for v2 table creation code paths

## What changes were proposed in this pull request?

Adds the following analyzer checks to the v2 table creation code paths:
- every column referenced by a partitioning transform exists in the table schema (including nested fields)
- partitioning transforms are not duplicated
- column names referenced by transforms are resolved according to the case sensitivity setting

An illustrative sketch of the resulting column-name normalization follows the diff.

## How was this patch tested?

Unit tests.

Closes #25305 from brkyvz/v2CreateTable.

Authored-by: Burak Yavuz
Signed-off-by: Wenchen Fan
---
 .../catalog/v2/expressions/expressions.scala  |  39 +++-
 .../sql/catalyst/analysis/CheckAnalysis.scala |   7 +-
 .../plans/logical/basicLogicalOperators.scala |  39 +++-
 .../apache/spark/sql/util/SchemaUtils.scala   | 155 ++++++++++++++-
 .../datasources/DataSourceResolution.scala    |   7 +-
 .../sql/execution/datasources/rules.scala     |  43 ++++-
 .../datasources/v2/V2SessionCatalog.scala     |   5 +-
 .../sql/sources/v2/DataSourceV2SQLSuite.scala | 182 +++++++++++++++++-
 8 files changed, 452 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala
index ea5fc05dd5..bceea147dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala
@@ -59,10 +59,18 @@ private[sql] object LogicalExpressions {
 
   def hours(column: String): HoursTransform = HoursTransform(reference(column))
 }
 
+/**
+ * Allows Spark to rewrite the references of a transform during analysis.
+ */
+sealed trait RewritableTransform extends Transform {
+  /** Creates a copy of this transform with the new analyzed references. */
+  def withReferences(newReferences: Seq[NamedReference]): Transform
+}
+
 /**
  * Base class for simple transforms of a single column.
*/ -private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends Transform { +private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends RewritableTransform { def reference: NamedReference = ref @@ -73,18 +81,24 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends T override def describe: String = name + "(" + reference.describe + ")" override def toString: String = describe + + protected def withNewRef(ref: NamedReference): Transform + + override def withReferences(newReferences: Seq[NamedReference]): Transform = { + assert(newReferences.length == 1, + s"Tried rewriting a single column transform (${this}) with multiple references.") + withNewRef(newReferences.head) + } } private[sql] final case class BucketTransform( numBuckets: Literal[Int], - columns: Seq[NamedReference]) extends Transform { + columns: Seq[NamedReference]) extends RewritableTransform { override val name: String = "bucket" override def references: Array[NamedReference] = { - arguments - .filter(_.isInstanceOf[NamedReference]) - .map(_.asInstanceOf[NamedReference]) + arguments.collect { case named: NamedReference => named } } override def arguments: Array[Expression] = numBuckets +: columns.toArray @@ -92,6 +106,10 @@ private[sql] final case class BucketTransform( override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})" override def toString: String = describe + + override def withReferences(newReferences: Seq[NamedReference]): Transform = { + this.copy(columns = newReferences) + } } private[sql] object BucketTransform { @@ -112,9 +130,7 @@ private[sql] final case class ApplyTransform( override def arguments: Array[Expression] = args.toArray override def references: Array[NamedReference] = { - arguments - .filter(_.isInstanceOf[NamedReference]) - .map(_.asInstanceOf[NamedReference]) + arguments.collect { case named: NamedReference => named } } override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})" @@ -143,7 +159,7 @@ private object Ref { /** * Convenience extractor for any Transform. 
*/ -private object NamedTransform { +private[sql] object NamedTransform { def unapply(transform: Transform): Some[(String, Seq[Expression])] = { Some((transform.name, transform.arguments)) } @@ -153,6 +169,7 @@ private[sql] final case class IdentityTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "identity" override def describe: String = ref.describe + override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref) } private[sql] object IdentityTransform { @@ -167,6 +184,7 @@ private[sql] object IdentityTransform { private[sql] final case class YearsTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "years" + override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref) } private[sql] object YearsTransform { @@ -181,6 +199,7 @@ private[sql] object YearsTransform { private[sql] final case class MonthsTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "months" + override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref) } private[sql] object MonthsTransform { @@ -195,6 +214,7 @@ private[sql] object MonthsTransform { private[sql] final case class DaysTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "days" + override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref) } private[sql] object DaysTransform { @@ -209,6 +229,7 @@ private[sql] object DaysTransform { private[sql] final case class HoursTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "hours" + override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref) } private[sql] object HoursTransform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ae19d02e44..519c558d12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils /** * Throws user facing errors when passed invalid queries that fail to analyze. 
@@ -299,10 +300,10 @@ trait CheckAnalysis extends PredicateHelper { } } - case CreateTableAsSelect(_, _, partitioning, query, _, _, _) => - val references = partitioning.flatMap(_.references).toSet + case create: V2CreateTablePlan => + val references = create.partitioning.flatMap(_.references).toSet val badReferences = references.map(_.fieldNames).flatMap { column => - query.schema.findNestedField(column) match { + create.tableSchema.findNestedField(column) match { case Some(_) => None case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 6f33944fc1..d9c370af47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -418,7 +418,11 @@ case class CreateV2Table( tableSchema: StructType, partitioning: Seq[Transform], properties: Map[String, String], - ignoreIfExists: Boolean) extends Command + ignoreIfExists: Boolean) extends Command with V2CreateTablePlan { + override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { + this.copy(partitioning = rewritten) + } +} /** * Create a new table from a select query with a v2 catalog. @@ -430,8 +434,9 @@ case class CreateTableAsSelect( query: LogicalPlan, properties: Map[String, String], writeOptions: Map[String, String], - ignoreIfExists: Boolean) extends Command { + ignoreIfExists: Boolean) extends Command with V2CreateTablePlan { + override def tableSchema: StructType = query.schema override def children: Seq[LogicalPlan] = Seq(query) override lazy val resolved: Boolean = childrenResolved && { @@ -440,6 +445,10 @@ case class CreateTableAsSelect( val references = partitioning.flatMap(_.references).toSet references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined) } + + override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { + this.copy(partitioning = rewritten) + } } /** @@ -456,7 +465,11 @@ case class ReplaceTable( tableSchema: StructType, partitioning: Seq[Transform], properties: Map[String, String], - orCreate: Boolean) extends Command + orCreate: Boolean) extends Command with V2CreateTablePlan { + override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { + this.copy(partitioning = rewritten) + } +} /** * Replaces a table from a select query with a v2 catalog. @@ -471,8 +484,9 @@ case class ReplaceTableAsSelect( query: LogicalPlan, properties: Map[String, String], writeOptions: Map[String, String], - orCreate: Boolean) extends Command { + orCreate: Boolean) extends Command with V2CreateTablePlan { + override def tableSchema: StructType = query.schema override def children: Seq[LogicalPlan] = Seq(query) override lazy val resolved: Boolean = { @@ -481,6 +495,10 @@ case class ReplaceTableAsSelect( val references = partitioning.flatMap(_.references).toSet references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined) } + + override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { + this.copy(partitioning = rewritten) + } } /** @@ -1201,3 +1219,16 @@ case class Deduplicate( override def output: Seq[Attribute] = child.output } + +/** A trait used for logical plan nodes that create or replace V2 table definitions. 
+ */
+trait V2CreateTablePlan extends LogicalPlan {
+  def tableName: Identifier
+  def partitioning: Seq[Transform]
+  def tableSchema: StructType
+
+  /**
+   * Creates a copy of this node with the new partitioning transforms. This method is used during
+   * analysis to rewrite the partition transforms so that they are normalized according to the
+   * table schema.
+   */
+  def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
index 052014ab86..d15440632f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -17,9 +17,12 @@ package org.apache.spark.sql.util
 
+import java.util.Locale
+
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2.expressions._
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
 
 
 /**
@@ -88,4 +91,154 @@ private[spark] object SchemaUtils {
         s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}")
     }
   }
+
+  /**
+   * Returns all column names in this schema as a flat list. For example, a schema like:
+   *   | - a
+   *   | | - 1
+   *   | | - 2
+   *   | - b
+   *   | - c
+   *   | | - nest
+   *   | | - 3
+   * will get flattened to: "a", "a.1", "a.2", "b", "c", "c.nest", "c.nest.3"
+   */
+  def explodeNestedFieldNames(schema: StructType): Seq[String] = {
+    def explode(schema: StructType): Seq[Seq[String]] = {
+      def recurseIntoComplexTypes(complexType: DataType): Seq[Seq[String]] = {
+        complexType match {
+          case s: StructType => explode(s)
+          case a: ArrayType => recurseIntoComplexTypes(a.elementType)
+          case m: MapType =>
+            recurseIntoComplexTypes(m.keyType).map(Seq("key") ++ _) ++
+              recurseIntoComplexTypes(m.valueType).map(Seq("value") ++ _)
+          case _ => Nil
+        }
+      }
+
+      schema.flatMap {
+        case StructField(name, s: StructType, _, _) =>
+          Seq(Seq(name)) ++ explode(s).map(nested => Seq(name) ++ nested)
+        case StructField(name, a: ArrayType, _, _) =>
+          Seq(Seq(name)) ++ recurseIntoComplexTypes(a).map(nested => Seq(name) ++ nested)
+        case StructField(name, m: MapType, _, _) =>
+          Seq(Seq(name)) ++ recurseIntoComplexTypes(m).map(nested => Seq(name) ++ nested)
+        case f => Seq(f.name) :: Nil
+      }
+    }
+
+    explode(schema).map(UnresolvedAttribute.apply(_).name)
+  }
+
+  /**
+   * Checks whether the partitioning transforms contain duplicates, and throws an exception if
+   * duplicates exist.
+   *
+   * @param transforms the partitioning transforms to check for duplicates
+   * @param checkType contextual information around the check, used in an exception message
+   * @param isCaseSensitive whether to be case sensitive when comparing column names
+   */
+  def checkTransformDuplication(
+      transforms: Seq[Transform],
+      checkType: String,
+      isCaseSensitive: Boolean): Unit = {
+    val extractedTransforms = transforms.map {
+      case b: BucketTransform =>
+        val colNames = b.columns.map(c => UnresolvedAttribute(c.fieldNames()).name)
+        // We need to check that we're not duplicating columns within our bucketing transform
+        checkColumnNameDuplication(colNames, "in the bucket definition", isCaseSensitive)
+        b.name -> colNames
+      case NamedTransform(transformName, refs) =>
+        val fieldNameParts =
+          refs.collect { case FieldReference(parts) => UnresolvedAttribute(parts).name }
+        // We could also check for duplicate column names here when fieldNameParts.length > 1,
+        // but we deliberately don't, because some transforms legitimately reference the same
+        // column more than once.
+        transformName -> fieldNameParts
+    }
+    val normalizedTransforms = if (isCaseSensitive) {
+      extractedTransforms
+    } else {
+      extractedTransforms.map(t => t._1 -> t._2.map(_.toLowerCase(Locale.ROOT)))
+    }
+
+    if (normalizedTransforms.distinct.length != normalizedTransforms.length) {
+      val duplicateColumns = normalizedTransforms.groupBy(identity).collect {
+        case (x, ys) if ys.length > 1 => s"${x._2.mkString(".")}"
+      }
+      throw new AnalysisException(
+        s"Found duplicate column(s) $checkType: ${duplicateColumns.mkString(", ")}")
+    }
+  }
+
+  /**
+   * Returns the given column's ordinal within the given `schema`. The returned position contains
+   * one index per level of nesting of the column.
+   *
+   * @param column The column to search for in the given struct. If the length of `column` is
+   *               greater than 1, we expect to enter a nested field.
+   * @param schema The current struct we are looking at.
+   * @param resolver The resolver to find the column.
+   */
+  def findColumnPosition(
+      column: Seq[String],
+      schema: StructType,
+      resolver: Resolver): Seq[Int] = {
+    def find(column: Seq[String], schema: StructType, stack: Seq[String]): Seq[Int] = {
+      if (column.isEmpty) return Nil
+      val thisCol = column.head
+      lazy val columnPath = UnresolvedAttribute(stack :+ thisCol).name
+      val pos = schema.indexWhere(f => resolver(f.name, thisCol))
+      if (pos == -1) {
+        throw new IndexOutOfBoundsException(columnPath)
+      }
+      val children = schema(pos).dataType match {
+        case s: StructType =>
+          find(column.tail, s, stack :+ thisCol)
+        case ArrayType(s: StructType, _) =>
+          find(column.tail, s, stack :+ thisCol)
+        case o =>
+          if (column.length > 1) {
+            throw new AnalysisException(
+              s"""Expected $columnPath to be a nested data type, but found $o. Was looking for the
+                 |index of ${UnresolvedAttribute(column).name} in a nested field
+               """.stripMargin)
+          }
+          Nil
+      }
+      Seq(pos) ++ children
+    }
+
+    try {
+      find(column, schema, Nil)
+    } catch {
+      case i: IndexOutOfBoundsException =>
+        throw new AnalysisException(
+          s"Couldn't find column ${i.getMessage} in:\n${schema.treeString}")
+      case e: AnalysisException =>
+        throw new AnalysisException(e.getMessage + s":\n${schema.treeString}")
+    }
+  }
+
+  /**
+   * Gets the name of the column in the given position.
+ */ + def getColumnName(position: Seq[Int], schema: StructType): Seq[String] = { + val topLevel = schema(position.head) + val field = position.tail.foldLeft(Seq(topLevel.name) -> topLevel) { + case (nameAndField, pos) => + nameAndField._2.dataType match { + case s: StructType => + val nowField = s(pos) + (nameAndField._1 :+ nowField.name) -> nowField + case ArrayType(s: StructType, _) => + val nowField = s(pos) + (nameAndField._1 :+ nowField.name) -> nowField + case _ => + throw new AnalysisException( + s"The positions provided ($pos) cannot be resolved in\n${schema.treeString}.") + } + } + field._1 + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index f17b31da57..a150a049f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog} import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedAttribute} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation} import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, ReplaceTable, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DescribeColumnStatement, DescribeTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} @@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSour import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.TableProvider import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.util.SchemaUtils case class DataSourceResolution( conf: SQLConf, @@ -123,7 +124,7 @@ case class DataSourceResolution( case replace: ReplaceTableStatement => // the provider was not a v1 source, convert to a v2 plan val CatalogObjectIdentifier(maybeCatalog, identifier) = replace.tableName - val catalog = maybeCatalog.orElse(defaultCatalog) + val catalog = maybeCatalog.orElse(sessionCatalog) .getOrElse(throw new AnalysisException( s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) .asTableCatalog @@ -132,7 +133,7 @@ case class DataSourceResolution( case rtas: ReplaceTableAsSelectStatement => // the provider was not a v1 source, convert to a v2 plan val CatalogObjectIdentifier(maybeCatalog, identifier) = rtas.tableName - val catalog = maybeCatalog.orElse(defaultCatalog) + val catalog = maybeCatalog.orElse(sessionCatalog) .getOrElse(throw new AnalysisException( s"No catalog specified for table ${identifier.quoted} and no default catalog is 
set")) .asTableCatalog diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index e8951bc8e7..a0a90503ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} +import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} @@ -29,7 +30,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.InsertableRelation -import org.apache.spark.sql.types.{AtomicType, StructType} +import org.apache.spark.sql.types.{ArrayType, AtomicType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils /** @@ -236,6 +237,46 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi c.copy(tableDesc = normalizedTable.copy(schema = reorderedSchema)) } + + case create: V2CreateTablePlan => + val schema = create.tableSchema + val partitioning = create.partitioning + val identifier = create.tableName + val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + // Check that columns are not duplicated in the schema + val flattenedSchema = SchemaUtils.explodeNestedFieldNames(schema) + SchemaUtils.checkColumnNameDuplication( + flattenedSchema, + s"in the table definition of $identifier", + isCaseSensitive) + + // Check that columns are not duplicated in the partitioning statement + SchemaUtils.checkTransformDuplication( + partitioning, "in the partitioning", isCaseSensitive) + + if (schema.isEmpty) { + if (partitioning.nonEmpty) { + throw new AnalysisException("It is not allowed to specify partitioning when the " + + "table schema is not defined.") + } + + create + } else { + // Resolve and normalize partition columns as necessary + val resolver = sparkSession.sessionState.conf.resolver + val normalizedPartitions = partitioning.map { + case transform: RewritableTransform => + val rewritten = transform.references().map { ref => + // Throws an exception if the reference cannot be resolved + val position = SchemaUtils.findColumnPosition(ref.fieldNames(), schema, resolver) + FieldReference(SchemaUtils.getColumnName(position, schema)) + } + transform.withReferences(rewritten) + case other => other + } + + create.withPartitioning(normalizedPartitions) + } } private def fallBackV2ToV1(cls: Class[_]): Class[_] = cls.newInstance match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 4cd0346b57..a3b8f28fc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2 import 
java.util -import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable @@ -90,10 +89,11 @@ class V2SessionCatalog(sessionState: SessionState) extends TableCatalog { val location = Option(properties.get("location")) val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap) .copy(locationUri = location.map(CatalogUtils.stringToURI)) + val tableType = if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED val tableDesc = CatalogTable( identifier = ident.asTableIdentifier, - tableType = CatalogTableType.MANAGED, + tableType = tableType, storage = storage, schema = schema, provider = Some(provider), @@ -252,4 +252,3 @@ private[sql] object V2SessionCatalog { (identityCols, bucketSpec) } } - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index d95021077f..9ae51d577b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -23,19 +23,22 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalog.v2.Identifier +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType, LongType, MapType, Metadata, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ private val orc2 = classOf[OrcDataSourceV2].getName + private val v2Source = classOf[FakeV2Provider].getName before { spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) @@ -1696,4 +1699,181 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn } } } + + test("tableCreation: partition column case insensitive resolution") { + val testCatalog = spark.catalog("testcat").asTableCatalog + val sessionCatalog = spark.catalog("session").asTableCatalog + + def checkPartitioning(cat: TableCatalog, partition: String): Unit = { + val table = cat.loadTable(Identifier.of(Array.empty, "tbl")) + val partitions = table.partitioning().map(_.references()) + assert(partitions.length === 1) + val fieldNames = partitions.flatMap(_.map(_.fieldNames())) + assert(fieldNames === Array(Array(partition))) + } + + sql(s"CREATE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (A)") + checkPartitioning(sessionCatalog, "a") + sql(s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (A)") + checkPartitioning(testCatalog, "a") + sql(s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (B)") + checkPartitioning(sessionCatalog, "b") + sql(s"CREATE OR 
REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (B)") + checkPartitioning(testCatalog, "b") + } + + test("tableCreation: partition column case sensitive resolution") { + def checkFailure(statement: String): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val e = intercept[AnalysisException] { + sql(statement) + } + assert(e.getMessage.contains("Couldn't find column")) + } + } + + checkFailure(s"CREATE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (A)") + checkFailure(s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (A)") + checkFailure( + s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (B)") + checkFailure( + s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (B)") + } + + test("tableCreation: duplicate column names in the table definition") { + val errorMsg = "Found duplicate column(s) in the table definition of `t`" + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + testCreateAnalysisError( + s"CREATE TABLE t ($c0 INT, $c1 INT) USING $v2Source", + errorMsg + ) + testCreateAnalysisError( + s"CREATE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE t ($c0 INT, $c1 INT) USING $v2Source", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source", + errorMsg + ) + } + } + } + + test("tableCreation: duplicate nested column names in the table definition") { + val errorMsg = "Found duplicate column(s) in the table definition of `t`" + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + testCreateAnalysisError( + s"CREATE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source", + errorMsg + ) + testCreateAnalysisError( + s"CREATE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source", + errorMsg + ) + } + } + } + + test("tableCreation: bucket column names not in table definition") { + val errorMsg = "Couldn't find column c in" + testCreateAnalysisError( + s"CREATE TABLE tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS", + errorMsg + ) + testCreateAnalysisError( + s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source " + + "CLUSTERED BY (c) INTO 4 BUCKETS", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source " + + "CLUSTERED BY (c) INTO 4 BUCKETS", + errorMsg + ) + } + + test("tableCreation: column repeated in partition columns") { + val errorMsg = "Found duplicate column(s) in the partitioning" + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + testCreateAnalysisError( + s"CREATE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)", + errorMsg + ) + testCreateAnalysisError( + s"CREATE TABLE 
testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)", + errorMsg + ) + } + } + } + + test("tableCreation: column repeated in bucket columns") { + val errorMsg = "Found duplicate column(s) in the bucket definition" + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + testCreateAnalysisError( + s"CREATE TABLE t ($c0 INT) USING $v2Source " + + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS", + errorMsg + ) + testCreateAnalysisError( + s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source " + + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source " + + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS", + errorMsg + ) + testCreateAnalysisError( + s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source " + + s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS", + errorMsg + ) + } + } + } + + private def testCreateAnalysisError(sqlStatement: String, expectedError: String): Unit = { + val errMsg = intercept[AnalysisException] { + sql(sqlStatement) + }.getMessage + assert(errMsg.contains(expectedError)) + } +} + + +/** Used as a V2 DataSource for V2SessionCatalog DDL */ +class FakeV2Provider extends TableProvider { + override def getTable(options: CaseInsensitiveStringMap): Table = { + throw new UnsupportedOperationException("Unnecessary for DDL tests") + } }
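
For reviewers, here is a minimal sketch of the normalization path this patch adds: `PreprocessTableCreation` resolves each transform reference with `SchemaUtils.findColumnPosition` and rewrites it through `SchemaUtils.getColumnName`, so partition columns end up spelled exactly as in the table schema. The object name and example schema below are illustrative only and not part of the patch; the snippet assumes it is compiled inside the Spark build (for example under `org.apache.spark.sql`), since `SchemaUtils` is `private[spark]`.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.types.{IntegerType, StructType, TimestampType}
import org.apache.spark.sql.util.SchemaUtils

// Hypothetical, illustration-only object; not part of this patch.
object PartitionNormalizationSketch {
  def main(args: Array[String]): Unit = {
    // Schema of: CREATE TABLE t (point struct<X: int, y: int>, ts timestamp)
    val schema = new StructType()
      .add("point", new StructType().add("X", IntegerType).add("y", IntegerType))
      .add("ts", TimestampType)

    // PARTITIONED BY (point.x): resolve the user-supplied reference case-insensitively.
    // An unresolvable reference makes findColumnPosition throw an AnalysisException,
    // which is what the "Couldn't find column" tests above exercise.
    val position =
      SchemaUtils.findColumnPosition(Seq("point", "x"), schema, caseInsensitiveResolution)

    // Map the position back to the schema's own spelling of the field names.
    // PreprocessTableCreation wraps this result in a FieldReference when rewriting the transform.
    val normalized = SchemaUtils.getColumnName(position, schema)
    println(normalized.mkString("."))  // prints: point.X
  }
}
```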