[SPARK-28837][SQL] CTAS/RTAS should use nullable schema


### What changes were proposed in this pull request?
When running CTAS/RTAS, use the nullable version of the input query's schema to create the table.
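
Concretely, the CTAS/RTAS physical plans now pass `query.schema.asNullable` to the catalog instead of `query.schema` (see the diff below). A minimal sketch of what `asNullable` does; note that `asNullable` is `private[spark]`, so this snippet is assumed to run inside Spark's own code base:

```scala
import org.apache.spark.sql.types._

// Schema inferred from a query like `SELECT 1 AS i`: the column is non-nullable.
val querySchema = new StructType().add("i", IntegerType, nullable = false)

// `asNullable` returns a copy with every field marked nullable, recursing into
// nested struct, array and map types. This is the schema used to create the table.
val tableSchema = querySchema.asNullable

assert(tableSchema("i").nullable) // the created table accepts nulls
```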

### Why are the changes needed?
It's quite common to run CTAS/RTAS with a non-nullable input query, e.g. `CREATE TABLE t AS SELECT 1`. However, it's surprising to users if they can't write null to this table later. Non-nullability is a constraint on the column and should be specified by users explicitly.
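
For illustration, a minimal reproduction with a V2 catalog (a sketch; `testcat` and the `foo` provider are the placeholders used by this PR's test suite and must be configured in the session):

```scala
// Before this PR the column was created as `i INT NOT NULL`,
// so the INSERT below was rejected; after this PR it succeeds.
spark.sql("CREATE TABLE testcat.t USING foo AS SELECT 1 AS i")
spark.sql("INSERT INTO testcat.t SELECT CAST(null AS INT)")
```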

For reference, Postgres also uses a nullable schema for CTAS:
```
> create table t1(i int not null);

> insert into t1 values (1);

> create table t2 as select i from t1;

> \d+ t1;
 Column |  Type   | Collation | Nullable | Default | Storage | Stats target | Description
--------+---------+-----------+----------+---------+---------+--------------+-------------
 i      | integer |           | not null |         | plain   |              |

> \d+ t2;
 Column |  Type   | Collation | Nullable | Default | Storage | Stats target | Description
--------+---------+-----------+----------+---------+---------+--------------+-------------
 i      | integer |           |          |         | plain   |              |

```

File source V1 has the same behavior.
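
A quick way to check the V1 behavior (a sketch; `t_v1` is a placeholder table name):

```scala
spark.sql("CREATE TABLE t_v1 USING parquet AS SELECT 1 AS i")
// The stored schema is nullable even though the query's output was not.
assert(spark.table("t_v1").schema("i").nullable)
```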

### Does this PR introduce any user-facing change?
Yes. After this PR, CTAS/RTAS creates tables with a nullable schema, so users can insert null values later.

### How was this patch tested?
A new test, "CreateTableAsSelect: nullable schema", was added to `DataSourceV2SQLSuite`.

Closes #25536 from cloud-fan/ctas.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

```diff
@@ -81,11 +81,12 @@ case class CreateTableAsSelectExec(
     }
 
     Utils.tryWithSafeFinallyAndFailureCallbacks({
+      val schema = query.schema.asNullable
       catalog.createTable(
-        ident, query.schema, partitioning.toArray, properties.asJava) match {
+        ident, schema, partitioning.toArray, properties.asJava) match {
         case table: SupportsWrite =>
           val writeBuilder = table.newWriteBuilder(writeOptions)
-            .withInputDataSchema(query.schema)
+            .withInputDataSchema(schema)
             .withQueryId(UUID.randomUUID().toString)
 
           writeBuilder match {
@@ -132,7 +133,7 @@ case class AtomicCreateTableAsSelectExec(
       throw new TableAlreadyExistsException(ident)
     }
     val stagedTable = catalog.stageCreate(
-      ident, query.schema, partitioning.toArray, properties.asJava)
+      ident, query.schema.asNullable, partitioning.toArray, properties.asJava)
     writeToStagedTable(stagedTable, writeOptions, ident)
   }
 }
@@ -173,13 +174,14 @@ case class ReplaceTableAsSelectExec(
     } else if (!orCreate) {
       throw new CannotReplaceMissingTableException(ident)
     }
+    val schema = query.schema.asNullable
     val createdTable = catalog.createTable(
-      ident, query.schema, partitioning.toArray, properties.asJava)
+      ident, schema, partitioning.toArray, properties.asJava)
     Utils.tryWithSafeFinallyAndFailureCallbacks({
       createdTable match {
         case table: SupportsWrite =>
           val writeBuilder = table.newWriteBuilder(writeOptions)
-            .withInputDataSchema(query.schema)
+            .withInputDataSchema(schema)
             .withQueryId(UUID.randomUUID().toString)
 
           writeBuilder match {
@@ -221,13 +223,14 @@ case class AtomicReplaceTableAsSelectExec(
     orCreate: Boolean) extends AtomicTableWriteExec {
 
   override protected def doExecute(): RDD[InternalRow] = {
+    val schema = query.schema.asNullable
     val staged = if (orCreate) {
       catalog.stageCreateOrReplace(
-        ident, query.schema, partitioning.toArray, properties.asJava)
+        ident, schema, partitioning.toArray, properties.asJava)
     } else if (catalog.tableExists(ident)) {
       try {
         catalog.stageReplace(
-          ident, query.schema, partitioning.toArray, properties.asJava)
+          ident, schema, partitioning.toArray, properties.asJava)
       } catch {
         case e: NoSuchTableException =>
           throw new CannotReplaceMissingTableException(ident, Some(e))
```

```diff
@@ -232,7 +232,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -258,8 +258,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(replacedTable.name == identifier)
     assert(replacedTable.partitioning.isEmpty)
     assert(replacedTable.properties == Map("provider" -> "foo").asJava)
-    assert(replacedTable.schema == new StructType()
-      .add("id", LongType, nullable = false))
+    assert(replacedTable.schema == new StructType().add("id", LongType))
 
     val rdd = spark.sparkContext.parallelize(replacedTable.asInstanceOf[InMemoryTable].rows)
     checkAnswer(
@@ -391,7 +390,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> orc2).asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -408,7 +407,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -428,7 +427,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table2.partitioning.isEmpty)
     assert(table2.properties == Map("provider" -> "foo").asJava)
     assert(table2.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -446,7 +445,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -477,7 +476,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(table.partitioning.isEmpty)
     assert(table.properties == Map("provider" -> "foo").asJava)
     assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
+      .add("id", LongType)
       .add("data", StringType))
 
     val rdd = sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
@@ -500,6 +499,32 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSparkSession with Before
     assert(t.isInstanceOf[UnresolvedTable], "V1 table wasn't returned as an unresolved table")
   }
 
+  test("CreateTableAsSelect: nullable schema") {
+    val basicCatalog = catalog("testcat").asTableCatalog
+    val atomicCatalog = catalog("testcat_atomic").asTableCatalog
+    val basicIdentifier = "testcat.table_name"
+    val atomicIdentifier = "testcat_atomic.table_name"
+
+    Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach {
+      case (catalog, identifier) =>
+        spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT 1 i")
+
+        val table = catalog.loadTable(Identifier.of(Array(), "table_name"))
+
+        assert(table.name == identifier)
+        assert(table.partitioning.isEmpty)
+        assert(table.properties == Map("provider" -> "foo").asJava)
+        assert(table.schema == new StructType().add("i", "int"))
+
+        val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+        checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Row(1))
+
+        sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
+        val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+        checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq(Row(1), Row(null)))
+    }
+  }
+
   test("DropTable: basic") {
     val tableName = "testcat.ns1.ns2.tbl"
     val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
```