# [SPARK-28572][SQL] Simple analyzer checks for v2 table creation code paths
## What changes were proposed in this pull request?

Adds analyzer checks around:
- the existence of partition transform references in the table schema (including nested fields),
- duplicated transforms, and
- case sensitivity of column names in the V2 table creation code paths.

## How was this patch tested?

Unit tests.

Closes #25305 from brkyvz/v2CreateTable.

Authored-by: Burak Yavuz <brkyvz@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit c80430f5c9 (parent 2580c1bfe2).
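For a concrete picture of what the new checks reject, here is an illustrative sketch drawn from the tests added at the end of this diff (`sql` and `intercept` as used in `DataSourceV2SQLSuite`; `v2Source` is the suite's `FakeV2Provider`):

```scala
// Statements the analyzer now rejects; the message fragments match the new checks.
val e1 = intercept[AnalysisException] {
  sql(s"CREATE TABLE testcat.t (aA INT, Aa INT) USING $v2Source") // duplicate, case-insensitive
}
assert(e1.getMessage.contains("Found duplicate column(s) in the table definition"))

val e2 = intercept[AnalysisException] {
  sql(s"CREATE TABLE testcat.t (a INT) USING $v2Source PARTITIONED BY (a, a)") // duplicated transform
}
assert(e2.getMessage.contains("Found duplicate column(s) in the partitioning"))

val e3 = intercept[AnalysisException] {
  sql(s"CREATE TABLE testcat.t (a INT) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS") // c not in schema
}
assert(e3.getMessage.contains("Couldn't find column c"))
```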
Transform expressions (`LogicalExpressions` and the `*Transform` classes): a new `RewritableTransform` trait lets the analyzer swap in normalized column references.

```diff
@@ -59,10 +59,18 @@ private[sql] object LogicalExpressions {
   def hours(column: String): HoursTransform = HoursTransform(reference(column))
 }
 
+/**
+ * Allows Spark to rewrite the given references of the transform during analysis.
+ */
+sealed trait RewritableTransform extends Transform {
+  /** Creates a copy of this transform with the new analyzed references. */
+  def withReferences(newReferences: Seq[NamedReference]): Transform
+}
+
 /**
  * Base class for simple transforms of a single column.
  */
-private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends Transform {
+private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends RewritableTransform {
 
   def reference: NamedReference = ref
 
@@ -73,18 +81,24 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends Transform {
   override def describe: String = name + "(" + reference.describe + ")"
 
   override def toString: String = describe
 
+  protected def withNewRef(ref: NamedReference): Transform
+
+  override def withReferences(newReferences: Seq[NamedReference]): Transform = {
+    assert(newReferences.length == 1,
+      s"Tried rewriting a single column transform (${this}) with multiple references.")
+    withNewRef(newReferences.head)
+  }
 }
 
 private[sql] final case class BucketTransform(
     numBuckets: Literal[Int],
-    columns: Seq[NamedReference]) extends Transform {
+    columns: Seq[NamedReference]) extends RewritableTransform {
 
   override val name: String = "bucket"
 
   override def references: Array[NamedReference] = {
-    arguments
-      .filter(_.isInstanceOf[NamedReference])
-      .map(_.asInstanceOf[NamedReference])
+    arguments.collect { case named: NamedReference => named }
   }
 
   override def arguments: Array[Expression] = numBuckets +: columns.toArray
@@ -92,6 +106,10 @@ private[sql] final case class BucketTransform(
   override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"
 
   override def toString: String = describe
 
+  override def withReferences(newReferences: Seq[NamedReference]): Transform = {
+    this.copy(columns = newReferences)
+  }
 }
 
 private[sql] object BucketTransform {
@@ -112,9 +130,7 @@ private[sql] final case class ApplyTransform(
   override def arguments: Array[Expression] = args.toArray
 
   override def references: Array[NamedReference] = {
-    arguments
-      .filter(_.isInstanceOf[NamedReference])
-      .map(_.asInstanceOf[NamedReference])
+    arguments.collect { case named: NamedReference => named }
   }
 
   override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
@@ -143,7 +159,7 @@ private object Ref {
 /**
  * Convenience extractor for any Transform.
  */
-private object NamedTransform {
+private[sql] object NamedTransform {
   def unapply(transform: Transform): Some[(String, Seq[Expression])] = {
     Some((transform.name, transform.arguments))
   }
@@ -153,6 +169,7 @@ private[sql] final case class IdentityTransform(
     ref: NamedReference) extends SingleColumnTransform(ref) {
   override val name: String = "identity"
   override def describe: String = ref.describe
+  override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
 }
 
 private[sql] object IdentityTransform {
@@ -167,6 +184,7 @@ private[sql] object IdentityTransform {
 private[sql] final case class YearsTransform(
     ref: NamedReference) extends SingleColumnTransform(ref) {
   override val name: String = "years"
+  override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
 }
 
 private[sql] object YearsTransform {
@@ -181,6 +199,7 @@ private[sql] object YearsTransform {
 private[sql] final case class MonthsTransform(
     ref: NamedReference) extends SingleColumnTransform(ref) {
   override val name: String = "months"
+  override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
 }
 
 private[sql] object MonthsTransform {
@@ -195,6 +214,7 @@ private[sql] object MonthsTransform {
 private[sql] final case class DaysTransform(
     ref: NamedReference) extends SingleColumnTransform(ref) {
   override val name: String = "days"
+  override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
 }
 
 private[sql] object DaysTransform {
@@ -209,6 +229,7 @@ private[sql] object DaysTransform {
 private[sql] final case class HoursTransform(
     ref: NamedReference) extends SingleColumnTransform(ref) {
   override val name: String = "hours"
+  override protected def withNewRef(ref: NamedReference): Transform = this.copy(ref)
 }
 
 private[sql] object HoursTransform {
```
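A minimal usage sketch of the new contract (assuming the classes above on the classpath; `FieldReference(Seq(...))` is the multipart-name constructor used elsewhere in this diff):

```scala
import org.apache.spark.sql.catalog.v2.expressions._

// A user-written hours(ts) transform; the analyzer can later swap in a normalized reference.
val transform: Transform = LogicalExpressions.hours("ts")
val rewritten = transform match {
  case r: RewritableTransform =>
    // Single-column transforms assert that exactly one new reference is supplied.
    r.withReferences(Seq(FieldReference(Seq("event_ts"))))
  case other => other
}
println(rewritten.describe) // hours(event_ts)
```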
`CheckAnalysis` now validates partition references for every `V2CreateTablePlan`, not just `CreateTableAsSelect`:

```diff
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils
 
 /**
  * Throws user facing errors when passed invalid queries that fail to analyze.
@@ -299,10 +300,10 @@ trait CheckAnalysis extends PredicateHelper {
         }
       }
 
-    case CreateTableAsSelect(_, _, partitioning, query, _, _, _) =>
-      val references = partitioning.flatMap(_.references).toSet
+    case create: V2CreateTablePlan =>
+      val references = create.partitioning.flatMap(_.references).toSet
       val badReferences = references.map(_.fieldNames).flatMap { column =>
-        query.schema.findNestedField(column) match {
+        create.tableSchema.findNestedField(column) match {
           case Some(_) =>
             None
           case _ =>
```
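The shape of the generalized check, as a standalone toy sketch (simplified; `findNestedField` stands in for the `StructType.findNestedField` used above):

```scala
// Every partition reference must resolve against the table schema. Since the check now
// matches any V2CreateTablePlan, it covers CREATE TABLE and REPLACE TABLE, not just CTAS.
def badPartitionReferences(
    partitionRefs: Set[Seq[String]],
    findNestedField: Seq[String] => Option[Any]): Set[String] =
  partitionRefs.collect {
    case column if findNestedField(column).isEmpty => column.mkString(".")
  }
```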
The v2 create/replace logical plans all implement a new `V2CreateTablePlan` trait:

```diff
@@ -418,7 +418,11 @@ case class CreateV2Table(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     properties: Map[String, String],
-    ignoreIfExists: Boolean) extends Command
+    ignoreIfExists: Boolean) extends Command with V2CreateTablePlan {
+  override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
+    this.copy(partitioning = rewritten)
+  }
+}
 
 /**
  * Create a new table from a select query with a v2 catalog.
@@ -430,8 +434,9 @@ case class CreateTableAsSelect(
     query: LogicalPlan,
     properties: Map[String, String],
     writeOptions: Map[String, String],
-    ignoreIfExists: Boolean) extends Command {
+    ignoreIfExists: Boolean) extends Command with V2CreateTablePlan {
 
+  override def tableSchema: StructType = query.schema
   override def children: Seq[LogicalPlan] = Seq(query)
 
   override lazy val resolved: Boolean = childrenResolved && {
@@ -440,6 +445,10 @@ case class CreateTableAsSelect(
     val references = partitioning.flatMap(_.references).toSet
     references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined)
   }
+
+  override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
+    this.copy(partitioning = rewritten)
+  }
 }
 
 /**
@@ -456,7 +465,11 @@ case class ReplaceTable(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     properties: Map[String, String],
-    orCreate: Boolean) extends Command
+    orCreate: Boolean) extends Command with V2CreateTablePlan {
+  override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
+    this.copy(partitioning = rewritten)
+  }
+}
 
 /**
  * Replaces a table from a select query with a v2 catalog.
@@ -471,8 +484,9 @@ case class ReplaceTableAsSelect(
     query: LogicalPlan,
     properties: Map[String, String],
     writeOptions: Map[String, String],
-    orCreate: Boolean) extends Command {
+    orCreate: Boolean) extends Command with V2CreateTablePlan {
 
+  override def tableSchema: StructType = query.schema
   override def children: Seq[LogicalPlan] = Seq(query)
 
   override lazy val resolved: Boolean = {
@@ -481,6 +495,10 @@ case class ReplaceTableAsSelect(
     val references = partitioning.flatMap(_.references).toSet
     references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined)
   }
+
+  override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
+    this.copy(partitioning = rewritten)
+  }
 }
 
 /**
@@ -1201,3 +1219,16 @@ case class Deduplicate(
 
   override def output: Seq[Attribute] = child.output
 }
+
+/** A trait used for logical plan nodes that create or replace V2 table definitions. */
+trait V2CreateTablePlan extends LogicalPlan {
+  def tableName: Identifier
+  def partitioning: Seq[Transform]
+  def tableSchema: StructType
+
+  /**
+   * Creates a copy of this node with the new partitioning transforms. This method is used to
+   * rewrite the partition transforms, normalized against the table schema.
+   */
+  def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan
+}
```
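A hedged sketch of how a rule can consume the trait without knowing the concrete node type (imports assumed from the expressions changes earlier in this diff):

```scala
import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, RewritableTransform, Transform}

// Rewrite every rewritable transform on any create/replace plan; opaque transforms pass through.
def withLowerCasedReferences(create: V2CreateTablePlan): V2CreateTablePlan = {
  val rewritten: Seq[Transform] = create.partitioning.map {
    case r: RewritableTransform =>
      r.withReferences(r.references().toSeq.map { ref =>
        FieldReference(ref.fieldNames().toSeq.map(_.toLowerCase))
      })
    case other => other
  }
  create.withPartitioning(rewritten)
}
```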
`SchemaUtils` gains helpers for flattening nested field names, detecting duplicated transforms, and resolving column positions:

```diff
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.util
 
+import java.util.Locale
+
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2.expressions._
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
 
 
 /**
@@ -88,4 +91,154 @@ private[spark] object SchemaUtils {
         s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}")
     }
   }
+
+  /**
+   * Returns all column names in this schema as a flat list. For example, a schema like:
+   *   | - a
+   *   | | - 1
+   *   | | - 2
+   *   | - b
+   *   | - c
+   *   | | - nest
+   *   |   | - 3
+   * will get flattened to: "a", "a.1", "a.2", "b", "c", "c.nest", "c.nest.3"
+   */
+  def explodeNestedFieldNames(schema: StructType): Seq[String] = {
+    def explode(schema: StructType): Seq[Seq[String]] = {
+      def recurseIntoComplexTypes(complexType: DataType): Seq[Seq[String]] = {
+        complexType match {
+          case s: StructType => explode(s)
+          case a: ArrayType => recurseIntoComplexTypes(a.elementType)
+          case m: MapType =>
+            recurseIntoComplexTypes(m.keyType).map(Seq("key") ++ _) ++
+              recurseIntoComplexTypes(m.valueType).map(Seq("value") ++ _)
+          case _ => Nil
+        }
+      }
+
+      schema.flatMap {
+        case StructField(name, s: StructType, _, _) =>
+          Seq(Seq(name)) ++ explode(s).map(nested => Seq(name) ++ nested)
+        case StructField(name, a: ArrayType, _, _) =>
+          Seq(Seq(name)) ++ recurseIntoComplexTypes(a).map(nested => Seq(name) ++ nested)
+        case StructField(name, m: MapType, _, _) =>
+          Seq(Seq(name)) ++ recurseIntoComplexTypes(m).map(nested => Seq(name) ++ nested)
+        case f => Seq(f.name) :: Nil
+      }
+    }
+
+    explode(schema).map(UnresolvedAttribute.apply(_).name)
+  }
+
+  /**
+   * Checks whether any of the partitioning transforms are duplicated. Throws an exception if
+   * duplication exists.
+   *
+   * @param transforms the transforms to check for duplicates
+   * @param checkType contextual information around the check, used in an exception message
+   * @param isCaseSensitive whether to be case sensitive when comparing column names
+   */
+  def checkTransformDuplication(
+      transforms: Seq[Transform],
+      checkType: String,
+      isCaseSensitive: Boolean): Unit = {
+    val extractedTransforms = transforms.map {
+      case b: BucketTransform =>
+        val colNames = b.columns.map(c => UnresolvedAttribute(c.fieldNames()).name)
+        // We need to check that we're not duplicating columns within our bucketing transform
+        checkColumnNameDuplication(colNames, "in the bucket definition", isCaseSensitive)
+        b.name -> colNames
+      case NamedTransform(transformName, refs) =>
+        val fieldNameParts =
+          refs.collect { case FieldReference(parts) => UnresolvedAttribute(parts).name }
+        // We could also check for duplicated column names here when
+        // fieldNameParts.length > 1, but we specifically don't: some transforms have
+        // legitimate uses for repeated references.
+        transformName -> fieldNameParts
+    }
+    val normalizedTransforms = if (isCaseSensitive) {
+      extractedTransforms
+    } else {
+      extractedTransforms.map(t => t._1 -> t._2.map(_.toLowerCase(Locale.ROOT)))
+    }
+
+    if (normalizedTransforms.distinct.length != normalizedTransforms.length) {
+      val duplicateColumns = normalizedTransforms.groupBy(identity).collect {
+        case (x, ys) if ys.length > 1 => s"${x._2.mkString(".")}"
+      }
+      throw new AnalysisException(
+        s"Found duplicate column(s) $checkType: ${duplicateColumns.mkString(", ")}")
+    }
+  }
+
+  /**
+   * Returns the given column's ordinal within the given `schema`. The returned position
+   * contains one ordinal per nesting level of the column.
+   *
+   * @param column The column to search for in the given struct. If the length of `column` is
+   *               greater than 1, we expect to enter a nested field.
+   * @param schema The current struct we are looking at.
+   * @param resolver The resolver to find the column.
+   */
+  def findColumnPosition(
+      column: Seq[String],
+      schema: StructType,
+      resolver: Resolver): Seq[Int] = {
+    def find(column: Seq[String], schema: StructType, stack: Seq[String]): Seq[Int] = {
+      if (column.isEmpty) return Nil
+      val thisCol = column.head
+      lazy val columnPath = UnresolvedAttribute(stack :+ thisCol).name
+      val pos = schema.indexWhere(f => resolver(f.name, thisCol))
+      if (pos == -1) {
+        throw new IndexOutOfBoundsException(columnPath)
+      }
+      val children = schema(pos).dataType match {
+        case s: StructType =>
+          find(column.tail, s, stack :+ thisCol)
+        case ArrayType(s: StructType, _) =>
+          find(column.tail, s, stack :+ thisCol)
+        case o =>
+          if (column.length > 1) {
+            throw new AnalysisException(
+              s"""Expected $columnPath to be a nested data type, but found $o. Was looking for the
+                 |index of ${UnresolvedAttribute(column).name} in a nested field
+                """.stripMargin)
+          }
+          Nil
+      }
+      Seq(pos) ++ children
+    }
+
+    try {
+      find(column, schema, Nil)
+    } catch {
+      case i: IndexOutOfBoundsException =>
+        throw new AnalysisException(
+          s"Couldn't find column ${i.getMessage} in:\n${schema.treeString}")
+      case e: AnalysisException =>
+        throw new AnalysisException(e.getMessage + s":\n${schema.treeString}")
+    }
+  }
+
+  /**
+   * Gets the name of the column in the given position.
+   */
+  def getColumnName(position: Seq[Int], schema: StructType): Seq[String] = {
+    val topLevel = schema(position.head)
+    val field = position.tail.foldLeft(Seq(topLevel.name) -> topLevel) {
+      case (nameAndField, pos) =>
+        nameAndField._2.dataType match {
+          case s: StructType =>
+            val nowField = s(pos)
+            (nameAndField._1 :+ nowField.name) -> nowField
+          case ArrayType(s: StructType, _) =>
+            val nowField = s(pos)
+            (nameAndField._1 :+ nowField.name) -> nowField
+          case _ =>
+            throw new AnalysisException(
+              s"The positions provided ($pos) cannot be resolved in\n${schema.treeString}.")
+        }
+    }
+    field._1
+  }
 }
```
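Usage sketch for the new helpers (the `identity` factory on `LogicalExpressions` is assumed by analogy with the `hours` factory shown at the top of this diff):

```scala
import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

val schema = new StructType()
  .add("a", new StructType().add("nest", IntegerType))
  .add("b", StringType)

// Flattened names, ready for duplicate checks: Seq("a", "a.nest", "b")
println(SchemaUtils.explodeNestedFieldNames(schema))

// Throws AnalysisException: identity(ts) and identity(TS) collide when case-insensitive.
SchemaUtils.checkTransformDuplication(
  Seq(LogicalExpressions.identity("ts"), LogicalExpressions.identity("TS")),
  "in the partitioning",
  isCaseSensitive = false)
```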
`DataSourceResolution` now resolves REPLACE TABLE statements against the session catalog:

```diff
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog}
 import org.apache.spark.sql.catalog.v2.expressions.Transform
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.CastSupport
+import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation}
 import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, ReplaceTable, ReplaceTableAsSelect}
 import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DescribeColumnStatement, DescribeTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement}
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSourceV2Relation}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.TableProvider
 import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType}
+import org.apache.spark.sql.util.SchemaUtils
 
 case class DataSourceResolution(
     conf: SQLConf,
@@ -123,7 +124,7 @@ case class DataSourceResolution(
     case replace: ReplaceTableStatement =>
       // the provider was not a v1 source, convert to a v2 plan
       val CatalogObjectIdentifier(maybeCatalog, identifier) = replace.tableName
-      val catalog = maybeCatalog.orElse(defaultCatalog)
+      val catalog = maybeCatalog.orElse(sessionCatalog)
         .getOrElse(throw new AnalysisException(
           s"No catalog specified for table ${identifier.quoted} and no default catalog is set"))
         .asTableCatalog
@@ -132,7 +133,7 @@ case class DataSourceResolution(
     case rtas: ReplaceTableAsSelectStatement =>
       // the provider was not a v1 source, convert to a v2 plan
       val CatalogObjectIdentifier(maybeCatalog, identifier) = rtas.tableName
-      val catalog = maybeCatalog.orElse(defaultCatalog)
+      val catalog = maybeCatalog.orElse(sessionCatalog)
         .getOrElse(throw new AnalysisException(
           s"No catalog specified for table ${identifier.quoted} and no default catalog is set"))
         .asTableCatalog
```
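The behavioral change in the last two hunks, in isolation (a sketch; Spark throws `AnalysisException` here, swapped for a plain exception so the snippet stands alone):

```scala
import org.apache.spark.sql.catalog.v2.CatalogPlugin

// REPLACE TABLE [AS SELECT] now falls back to the session catalog rather than the
// default catalog when the identifier does not name a catalog explicitly.
def resolveCatalog(
    maybeCatalog: Option[CatalogPlugin],
    sessionCatalog: Option[CatalogPlugin]): CatalogPlugin =
  maybeCatalog.orElse(sessionCatalog).getOrElse(
    throw new IllegalArgumentException(
      "No catalog specified for table and no default catalog is set"))
```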
`PreprocessTableCreation` runs the new duplicate checks and normalizes partition transforms against the table schema:

```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Locale
 
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
+import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, RewritableTransform}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering}
@@ -29,7 +30,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.InsertableRelation
-import org.apache.spark.sql.types.{AtomicType, StructType}
+import org.apache.spark.sql.types.{ArrayType, AtomicType, StructField, StructType}
 import org.apache.spark.sql.util.SchemaUtils
 
 /**
@@ -236,6 +237,46 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
 
       c.copy(tableDesc = normalizedTable.copy(schema = reorderedSchema))
     }
+
+    case create: V2CreateTablePlan =>
+      val schema = create.tableSchema
+      val partitioning = create.partitioning
+      val identifier = create.tableName
+      val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+      // Check that columns are not duplicated in the schema
+      val flattenedSchema = SchemaUtils.explodeNestedFieldNames(schema)
+      SchemaUtils.checkColumnNameDuplication(
+        flattenedSchema,
+        s"in the table definition of $identifier",
+        isCaseSensitive)
+
+      // Check that columns are not duplicated in the partitioning statement
+      SchemaUtils.checkTransformDuplication(
+        partitioning, "in the partitioning", isCaseSensitive)
+
+      if (schema.isEmpty) {
+        if (partitioning.nonEmpty) {
+          throw new AnalysisException("It is not allowed to specify partitioning when the " +
+            "table schema is not defined.")
+        }
+
+        create
+      } else {
+        // Resolve and normalize partition columns as necessary
+        val resolver = sparkSession.sessionState.conf.resolver
+        val normalizedPartitions = partitioning.map {
+          case transform: RewritableTransform =>
+            val rewritten = transform.references().map { ref =>
+              // Throws an exception if the reference cannot be resolved
+              val position = SchemaUtils.findColumnPosition(ref.fieldNames(), schema, resolver)
+              FieldReference(SchemaUtils.getColumnName(position, schema))
+            }
+            transform.withReferences(rewritten)
+          case other => other
+        }
+
+        create.withPartitioning(normalizedPartitions)
+      }
 
   private def fallBackV2ToV1(cls: Class[_]): Class[_] = cls.newInstance match {
```
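Walking the normalization branch with one concrete transform (a sketch; `hours` is the factory shown at the top of this diff, and `caseInsensitiveResolution` comes from the `analysis` package object):

```scala
import org.apache.spark.sql.catalog.v2.expressions._
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

val schema = new StructType().add("ts", TimestampType).add("value", StringType)

// The user wrote PARTITIONED BY (hours(TS)); rewrite it to the schema's own casing.
val userTransform: Transform = LogicalExpressions.hours("TS")
val normalized = userTransform match {
  case r: RewritableTransform =>
    val rewritten = r.references().toSeq.map { ref =>
      val pos = SchemaUtils.findColumnPosition(ref.fieldNames(), schema, caseInsensitiveResolution)
      FieldReference(SchemaUtils.getColumnName(pos, schema))
    }
    r.withReferences(rewritten)
  case other => other
}
println(normalized.describe) // hours(ts)
```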
`V2SessionCatalog` now marks tables with an explicit location as EXTERNAL:

```diff
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import java.util
-import java.util.Locale
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -90,10 +89,11 @@ class V2SessionCatalog(sessionState: SessionState) extends TableCatalog {
     val location = Option(properties.get("location"))
     val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap)
       .copy(locationUri = location.map(CatalogUtils.stringToURI))
+    val tableType = if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED
 
     val tableDesc = CatalogTable(
       identifier = ident.asTableIdentifier,
-      tableType = CatalogTableType.MANAGED,
+      tableType = tableType,
       storage = storage,
       schema = schema,
       provider = Some(provider),
@@ -252,4 +252,3 @@ private[sql] object V2SessionCatalog {
     (identityCols, bucketSpec)
   }
 }
-
```
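The new table-type rule in isolation (sketch with an assumed properties map):

```scala
import java.util.{HashMap => JHashMap}
import org.apache.spark.sql.catalyst.catalog.CatalogTableType

val properties = new JHashMap[String, String]()
properties.put("location", "/tmp/external/path") // hypothetical user-supplied location

// A user-supplied location now yields an EXTERNAL table; otherwise it stays MANAGED.
val location = Option(properties.get("location"))
val tableType = if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED
assert(tableType == CatalogTableType.EXTERNAL)
```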
New tests in `DataSourceV2SQLSuite`:

```diff
@@ -23,19 +23,22 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
-import org.apache.spark.sql.catalog.v2.Identifier
+import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog}
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType, LongType, MapType, Metadata, StringType, StructField, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
 
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
 
   private val orc2 = classOf[OrcDataSourceV2].getName
+  private val v2Source = classOf[FakeV2Provider].getName
 
   before {
     spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
@@ -1696,4 +1699,181 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
       }
     }
   }
+
+  test("tableCreation: partition column case insensitive resolution") {
+    val testCatalog = spark.catalog("testcat").asTableCatalog
+    val sessionCatalog = spark.catalog("session").asTableCatalog
+
+    def checkPartitioning(cat: TableCatalog, partition: String): Unit = {
+      val table = cat.loadTable(Identifier.of(Array.empty, "tbl"))
+      val partitions = table.partitioning().map(_.references())
+      assert(partitions.length === 1)
+      val fieldNames = partitions.flatMap(_.map(_.fieldNames()))
+      assert(fieldNames === Array(Array(partition)))
+    }
+
+    sql(s"CREATE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
+    checkPartitioning(sessionCatalog, "a")
+    sql(s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
+    checkPartitioning(testCatalog, "a")
+    sql(s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
+    checkPartitioning(sessionCatalog, "b")
+    sql(s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
+    checkPartitioning(testCatalog, "b")
+  }
+
+  test("tableCreation: partition column case sensitive resolution") {
+    def checkFailure(statement: String): Unit = {
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+        val e = intercept[AnalysisException] {
+          sql(statement)
+        }
+        assert(e.getMessage.contains("Couldn't find column"))
+      }
+    }
+
+    checkFailure(s"CREATE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
+    checkFailure(s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (A)")
+    checkFailure(
+      s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
+    checkFailure(
+      s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source PARTITIONED BY (B)")
+  }
+
+  test("tableCreation: duplicate column names in the table definition") {
+    val errorMsg = "Found duplicate column(s) in the table definition of `t`"
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        testCreateAnalysisError(
+          s"CREATE TABLE t ($c0 INT, $c1 INT) USING $v2Source",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE t ($c0 INT, $c1 INT) USING $v2Source",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE testcat.t ($c0 INT, $c1 INT) USING $v2Source",
+          errorMsg
+        )
+      }
+    }
+  }
+
+  test("tableCreation: duplicate nested column names in the table definition") {
+    val errorMsg = "Found duplicate column(s) in the table definition of `t`"
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        testCreateAnalysisError(
+          s"CREATE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE testcat.t (d struct<$c0: INT, $c1: INT>) USING $v2Source",
+          errorMsg
+        )
+      }
+    }
+  }
+
+  test("tableCreation: bucket column names not in table definition") {
+    val errorMsg = "Couldn't find column c in"
+    testCreateAnalysisError(
+      s"CREATE TABLE tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS",
+      errorMsg
+    )
+    testCreateAnalysisError(
+      s"CREATE TABLE testcat.tbl (a int, b string) USING $v2Source CLUSTERED BY (c) INTO 4 BUCKETS",
+      errorMsg
+    )
+    testCreateAnalysisError(
+      s"CREATE OR REPLACE TABLE tbl (a int, b string) USING $v2Source " +
+        "CLUSTERED BY (c) INTO 4 BUCKETS",
+      errorMsg
+    )
+    testCreateAnalysisError(
+      s"CREATE OR REPLACE TABLE testcat.tbl (a int, b string) USING $v2Source " +
+        "CLUSTERED BY (c) INTO 4 BUCKETS",
+      errorMsg
+    )
+  }
+
+  test("tableCreation: column repeated in partition columns") {
+    val errorMsg = "Found duplicate column(s) in the partitioning"
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        testCreateAnalysisError(
+          s"CREATE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source PARTITIONED BY ($c0, $c1)",
+          errorMsg
+        )
+      }
+    }
+  }
+
+  test("tableCreation: column repeated in bucket columns") {
+    val errorMsg = "Found duplicate column(s) in the bucket definition"
+    Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        testCreateAnalysisError(
+          s"CREATE TABLE t ($c0 INT) USING $v2Source " +
+            s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE TABLE testcat.t ($c0 INT) USING $v2Source " +
+            s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE t ($c0 INT) USING $v2Source " +
+            s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
+          errorMsg
+        )
+        testCreateAnalysisError(
+          s"CREATE OR REPLACE TABLE testcat.t ($c0 INT) USING $v2Source " +
+            s"CLUSTERED BY ($c0, $c1) INTO 2 BUCKETS",
+          errorMsg
+        )
+      }
+    }
+  }
+
+  private def testCreateAnalysisError(sqlStatement: String, expectedError: String): Unit = {
+    val errMsg = intercept[AnalysisException] {
+      sql(sqlStatement)
+    }.getMessage
+    assert(errMsg.contains(expectedError))
+  }
 }
+
+
+/** Used as a V2 DataSource for V2SessionCatalog DDL */
+class FakeV2Provider extends TableProvider {
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    throw new UnsupportedOperationException("Unnecessary for DDL tests")
+  }
+}
```