From ed3ea6734c8459b4062789ef53fe792143a2011c Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 22 Aug 2019 09:49:18 +0800
Subject: [PATCH] [SPARK-28837][SQL] CTAS/RTAS should use nullable schema

### What changes were proposed in this pull request?

When running CTAS/RTAS, use the nullable schema of the input query to create the table.

### Why are the changes needed?

It's very common to run CTAS/RTAS with a non-nullable input query, e.g. `CREATE TABLE t AS SELECT 1`. However, it's surprising to users if they can't write null to this table later. Non-nullability is a constraint on the column and should be specified by users explicitly.

For reference, Postgres also uses a nullable schema for CTAS:
```
> create table t1(i int not null);
> insert into t1 values (1);
> create table t2 as select i from t1;
> \d+ t1;
 Column |  Type   | Collation | Nullable | Default | Storage | Stats target | Description
--------+---------+-----------+----------+---------+---------+--------------+-------------
 i      | integer |           | not null |         | plain   |              |
> \d+ t2;
 Column |  Type   | Collation | Nullable | Default | Storage | Stats target | Description
--------+---------+-----------+----------+---------+---------+--------------+-------------
 i      | integer |           |          |         | plain   |              |
```
File source V1 has the same behavior.

### Does this PR introduce any user-facing change?

Yes. After this PR, CTAS/RTAS creates tables with a nullable schema, so users can insert null values later.

### How was this patch tested?

A new test.

Closes #25536 from cloud-fan/ctas.

Authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
---
 .../v2/WriteToDataSourceV2Exec.scala          | 17 ++++----
 .../sql/sources/v2/DataSourceV2SQLSuite.scala | 41 +++++++++++++++----
 2 files changed, 43 insertions(+), 15 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 39269a3f43..0131d72ebc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -81,11 +81,12 @@ case class CreateTableAsSelectExec(
     }
 
     Utils.tryWithSafeFinallyAndFailureCallbacks({
+      val schema = query.schema.asNullable
       catalog.createTable(
-        ident, query.schema, partitioning.toArray, properties.asJava) match {
+        ident, schema, partitioning.toArray, properties.asJava) match {
         case table: SupportsWrite =>
           val writeBuilder = table.newWriteBuilder(writeOptions)
-            .withInputDataSchema(query.schema)
+            .withInputDataSchema(schema)
             .withQueryId(UUID.randomUUID().toString)
 
           writeBuilder match {
@@ -132,7 +133,7 @@ case class AtomicCreateTableAsSelectExec(
       throw new TableAlreadyExistsException(ident)
     }
     val stagedTable = catalog.stageCreate(
-      ident, query.schema, partitioning.toArray, properties.asJava)
+      ident, query.schema.asNullable, partitioning.toArray, properties.asJava)
     writeToStagedTable(stagedTable, writeOptions, ident)
   }
 }
@@ -173,13 +174,14 @@ case class ReplaceTableAsSelectExec(
     } else if (!orCreate) {
       throw new CannotReplaceMissingTableException(ident)
     }
+    val schema = query.schema.asNullable
     val createdTable = catalog.createTable(
-      ident, query.schema, partitioning.toArray, properties.asJava)
+      ident, schema, partitioning.toArray, properties.asJava)
     Utils.tryWithSafeFinallyAndFailureCallbacks({
       createdTable match {
         case table: SupportsWrite =>
           val writeBuilder = table.newWriteBuilder(writeOptions)
-            .withInputDataSchema(query.schema)
+            .withInputDataSchema(schema)
             .withQueryId(UUID.randomUUID().toString)
 
           writeBuilder match {
@@ -221,13 +223,14 @@ case class AtomicReplaceTableAsSelectExec(
     orCreate: Boolean) extends AtomicTableWriteExec {
 
   override protected def doExecute(): RDD[InternalRow] = {
+    val schema = query.schema.asNullable
     val staged = if (orCreate) {
       catalog.stageCreateOrReplace(
-        ident, query.schema, partitioning.toArray, properties.asJava)
+        ident, schema, partitioning.toArray, properties.asJava)
     } else if (catalog.tableExists(ident)) {
       try {
         catalog.stageReplace(
-          ident, query.schema, partitioning.toArray, properties.asJava)
+          ident, schema, partitioning.toArray, properties.asJava)
       } catch {
         case e: NoSuchTableException =>
           throw new CannotReplaceMissingTableException(ident, Some(e))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
index 08082d88c0..732b3e1d67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
@@ -232,7 +232,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -258,8 +258,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(replacedTable.name == identifier)
     assert(replacedTable.partitioning.isEmpty)
     assert(replacedTable.properties == Map("provider" -> "foo").asJava)
-    assert(replacedTable.schema == new StructType()
-      .add("id", LongType, nullable = false))
+    assert(replacedTable.schema == new StructType().add("id", LongType))
 
     val rdd = spark.sparkContext.parallelize(replacedTable.asInstanceOf[InMemoryTable].rows)
     checkAnswer(
@@ -391,7 +390,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> orc2).asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -408,7 +407,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -428,7 +427,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table2.partitioning.isEmpty)
     assert(table2.properties == Map("provider" -> "foo").asJava)
     assert(table2.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -446,7 +445,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -477,7 +476,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -500,6 +499,32 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(t.isInstanceOf[UnresolvedTable], "V1 table wasn't returned as an unresolved table")
   }
 
+  test("CreateTableAsSelect: nullable schema") {
+    val basicCatalog = catalog("testcat").asTableCatalog
+    val atomicCatalog = catalog("testcat_atomic").asTableCatalog
+    val basicIdentifier = "testcat.table_name"
+    val atomicIdentifier = "testcat_atomic.table_name"
+
+    Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach {
+      case (catalog, identifier) =>
+        spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT 1 i")
+
+        val table = catalog.loadTable(Identifier.of(Array(), "table_name"))
+
+        assert(table.name == identifier)
+        assert(table.partitioning.isEmpty)
+        assert(table.properties == Map("provider" -> "foo").asJava)
+        assert(table.schema == new StructType().add("i", "int"))
+
+        val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+        checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Row(1))
+
+        sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
+        val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+        checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq(Row(1), Row(null)))
+    }
+  }
+
   test("DropTable: basic") {
     val tableName = "testcat.ns1.ns2.tbl"
     val ident = Identifier.of(Array("ns1", "ns2"), "tbl")