[SPARK-19279][SQL][FOLLOW-UP] Infer Schema for Hive Serde Tables

### What changes were proposed in this pull request?
`table.schema` is never empty for a partitioned table, because it also contains the partition columns, even when the original table definition declares no data columns. As a result, schema inference for such Hive serde tables was skipped. This PR fixes the issue by adding a `dataSchema` method to `CatalogTable` and checking `table.dataSchema.nonEmpty` instead of `table.schema.nonEmpty` in `HiveUtils.inferSchema`.
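
For illustration, a minimal sketch of the scenario (adapted from the new test case below; it assumes a Hive-enabled `SparkSession` named `spark` and the Avro serde on the classpath):

```scala
// A partitioned Hive serde table whose data columns come only from the Avro schema
// literal. Its catalog schema holds just the partition column `ds`, so the old
// `table.schema.nonEmpty` check wrongly concluded that no inference was needed.
spark.sql(
  """
    |CREATE TABLE tab1
    |PARTITIONED BY (ds STRING)
    |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe'
    |STORED AS
    |  INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat'
    |  OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat'
    |TBLPROPERTIES ('avro.schema.literal' =
    |  '{"name": "test_record", "type": "record", "fields": [{"name": "f0", "type": "int"}]}')
  """.stripMargin)
```

With this change, `HiveUtils.inferSchema` fills in the data column `f0` from the serde, so reading the table returns the data column plus `ds`, as the test below verifies.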

### How was this patch tested?
Added a test case

Author: gatorsmile <gatorsmile@gmail.com>

Closes #16848 from gatorsmile/inferHiveSerdeSchema.
gatorsmile committed on 2017-02-08 10:11:44 -05:00
parent 0077bfcb93
commit 4d4d0de7f6
3 changed files with 53 additions and 1 deletion

```diff
@@ -194,6 +194,14 @@ case class CatalogTable(
     StructType(partitionFields)
   }
 
+  /**
+   * schema of this table's data columns
+   */
+  def dataSchema: StructType = {
+    val dataFields = schema.dropRight(partitionColumnNames.length)
+    StructType(dataFields)
+  }
+
   /** Return the database this table was specified to belong to, assuming it exists. */
   def database: String = identifier.database.getOrElse {
     throw new AnalysisException(s"table $identifier did not specify database")
```
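
As a rough sketch of what the new `dataSchema` helper computes, using plain `StructType` values rather than a real `CatalogTable` (the column names are illustrative):

```scala
import org.apache.spark.sql.types._

// Full table schema: data columns first, partition columns appended at the end.
val schema = StructType(Seq(
  StructField("f0", IntegerType),   // data column
  StructField("ds", StringType)))   // partition column
val partitionColumnNames = Seq("ds")

// partitionSchema takes the trailing partition columns...
val partitionSchema = StructType(schema.takeRight(partitionColumnNames.length))
// ...and the new dataSchema drops them, leaving only the data columns.
val dataSchema = StructType(schema.dropRight(partitionColumnNames.length))
// dataSchema == StructType(Seq(StructField("f0", IntegerType)))
```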

```diff
@@ -580,7 +580,7 @@ private[spark] object HiveUtils extends Logging {
    * CatalogTable.
    */
   def inferSchema(table: CatalogTable): CatalogTable = {
-    if (DDLUtils.isDatasourceTable(table) || table.schema.nonEmpty) {
+    if (DDLUtils.isDatasourceTable(table) || table.dataSchema.nonEmpty) {
       table
     } else {
       val hiveTable = toHiveTable(table)
```

```diff
@@ -27,6 +27,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.execution.command.CreateTableCommand
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.hive.HiveExternalCatalog._
 import org.apache.spark.sql.hive.client.HiveClient
@@ -1308,6 +1309,49 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
     }
   }
 
+  test("Infer schema for Hive serde tables") {
+    val tableName = "tab1"
+    val avroSchema =
+      """{
+        |  "name": "test_record",
+        |  "type": "record",
+        |  "fields": [ {
+        |    "name": "f0",
+        |    "type": "int"
+        |  }]
+        |}
+      """.stripMargin
+
+    Seq(true, false).foreach { isPartitioned =>
+      withTable(tableName) {
+        val partitionClause = if (isPartitioned) "PARTITIONED BY (ds STRING)" else ""
+        // Creates the (non-)partitioned Avro table
+        val plan = sql(
+          s"""
+             |CREATE TABLE $tableName
+             |$partitionClause
+             |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe'
+             |STORED AS
+             |  INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat'
+             |  OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat'
+             |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema')
+           """.stripMargin
+        ).queryExecution.analyzed
+        assert(plan.isInstanceOf[CreateTableCommand] &&
+          plan.asInstanceOf[CreateTableCommand].table.dataSchema.nonEmpty)
+
+        if (isPartitioned) {
+          sql(s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1")
+          checkAnswer(spark.table(tableName), Row(1, "a"))
+        } else {
+          sql(s"INSERT OVERWRITE TABLE $tableName SELECT 1")
+          checkAnswer(spark.table(tableName), Row(1))
+        }
+      }
+    }
+  }
+
   private def withDebugMode(f: => Unit): Unit = {
     val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE)
     try {
```