[SPARK-10092] [SQL] Multi-DB support follow up.

https://issues.apache.org/jira/browse/SPARK-10092

This PR is a follow-up for Multi-DB support. It makes the following changes (a brief usage sketch follows the list):

* `HiveContext.refreshTable` now accepts `dbName.tableName`.
* `HiveContext.analyze` now accepts `dbName.tableName`.
* `CreateTableUsing`, `CreateTableUsingAsSelect`, `CreateTempTableUsing`, `CreateTempTableUsingAsSelect`, `CreateMetastoreDataSource`, and `CreateMetastoreDataSourceAsSelect` all take `TableIdentifier` instead of the string representation of table name.
* When you call `saveAsTable` with a specified database, the data will be saved to the correct location.
* Explicitly disallow creating a temporary table with a specified database name (this was not possible before either).
* When saving a table to the metastore, we also check that the database name and table name are accepted by Hive (using `MetaStoreUtils.validateName`).

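For context, the sketch below shows the intended end-to-end behavior. It is illustrative only: `sqlContext` is assumed to be a `HiveContext`, and `df` and the database name `mydb` are placeholders, not part of this patch.

```scala
// Illustrative only: `mydb` and `df` are placeholders.
sqlContext.sql("CREATE DATABASE IF NOT EXISTS mydb")

// The table's data is written under mydb's warehouse location, even when mydb is
// not the current database.
df.write.format("parquet").mode("overwrite").saveAsTable("mydb.t")

// Qualified names are now parsed with SqlParser, so refreshTable and table() accept them.
sqlContext.refreshTable("mydb.t")
val restored = sqlContext.table("mydb.t")
```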
Author: Yin Huai <yhuai@databricks.com>

Closes #8324 from yhuai/saveAsTableDB.
Commit 43e0135421 (parent b762f9920f), authored by Yin Huai on 2015-08-20 15:30:31 +08:00 and committed by Cheng Lian.
16 changed files with 400 additions and 96 deletions

View file

@ -25,7 +25,9 @@ private[sql] case class TableIdentifier(table: String, database: Option[String]
def toSeq: Seq[String] = database.toSeq :+ table
override def toString: String = toSeq.map("`" + _ + "`").mkString(".")
override def toString: String = quotedString
def quotedString: String = toSeq.map("`" + _ + "`").mkString(".")
def unquotedString: String = toSeq.mkString(".")
}

View file

@ -23,6 +23,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
@ -55,12 +56,15 @@ trait Catalog {
def refreshTable(tableIdent: TableIdentifier): Unit
// TODO: Refactor it in the work of SPARK-10104
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
// TODO: Refactor it in the work of SPARK-10104
def unregisterTable(tableIdentifier: Seq[String]): Unit
def unregisterAllTables(): Unit
// TODO: Refactor it in the work of SPARK-10104
protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = {
if (conf.caseSensitiveAnalysis) {
tableIdentifier
@ -69,6 +73,7 @@ trait Catalog {
}
}
// TODO: Refactor it in the work of SPARK-10104
protected def getDbTableName(tableIdent: Seq[String]): String = {
val size = tableIdent.size
if (size <= 2) {
@ -78,9 +83,22 @@ trait Catalog {
}
}
// TODO: Refactor it in the work of SPARK-10104
protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
(tableIdent.lift(tableIdent.size - 2), tableIdent.last)
}
/**
* It is not allowed to specify a database name for tables stored in [[SimpleCatalog]].
* We use this method to check it.
*/
protected def checkTableIdentifier(tableIdentifier: Seq[String]): Unit = {
if (tableIdentifier.length > 1) {
throw new AnalysisException("Specifying database name or other qualifiers are not allowed " +
"for temporary tables. If the table name has dots (.) in it, please quote the " +
"table name with backticks (`).")
}
}
}
class SimpleCatalog(val conf: CatalystConf) extends Catalog {
@ -89,11 +107,13 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
tables.put(getDbTableName(tableIdent), plan)
}
override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
tables.remove(getDbTableName(tableIdent))
}
@ -103,6 +123,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
}
override def tableExists(tableIdentifier: Seq[String]): Boolean = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
tables.containsKey(getDbTableName(tableIdent))
}
@ -110,6 +131,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
override def lookupRelation(
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
val tableFullName = getDbTableName(tableIdent)
val table = tables.get(tableFullName)
@ -149,7 +171,13 @@ trait OverrideCatalog extends Catalog {
abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.get(getDBTable(tableIdent)) match {
// A temporary table only has a single part in the tableIdentifier.
val overriddenTable = if (tableIdentifier.length > 1) {
None: Option[LogicalPlan]
} else {
overrides.get(getDBTable(tableIdent))
}
overriddenTable match {
case Some(_) => true
case None => super.tableExists(tableIdentifier)
}
@ -159,7 +187,12 @@ trait OverrideCatalog extends Catalog {
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
val tableIdent = processTableIdentifier(tableIdentifier)
val overriddenTable = overrides.get(getDBTable(tableIdent))
// A temporary table only has a single part in the tableIdentifier.
val overriddenTable = if (tableIdentifier.length > 1) {
None: Option[LogicalPlan]
} else {
overrides.get(getDBTable(tableIdent))
}
val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
@ -171,20 +204,8 @@ trait OverrideCatalog extends Catalog {
}
abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
val dbName = if (conf.caseSensitiveAnalysis) {
databaseName
} else {
if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None
}
val temporaryTables = overrides.filter {
// If a temporary table does not have an associated database, we should return its name.
case ((None, _), _) => true
// If a temporary table does have an associated database, we should return it if the database
// matches the given database name.
case ((db: Some[String], _), _) if db == dbName => true
case _ => false
}.map {
// We always return all temporary tables.
val temporaryTables = overrides.map {
case ((_, tableName), _) => (tableName, true)
}.toSeq
@ -194,13 +215,19 @@ trait OverrideCatalog extends Catalog {
override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.put(getDBTable(tableIdent), plan)
}
override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.remove(getDBTable(tableIdent))
// A temporary table only has a single part in the tableIdentifier.
// If tableIdentifier has more than one part, it is not a temporary table
// and we do not need to do anything here.
if (tableIdentifier.length == 1) {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.remove(getDBTable(tableIdent))
}
}
override def unregisterAllTables(): Unit = {

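As an editorial sketch of what `checkTableIdentifier` enforces (not part of the diff; `SimpleCatalystConf` and `OneRowRelation` from catalyst are used only for illustration): a multi-part identifier is rejected for temporary tables, while a single-part name that merely contains a dot is fine.

```scala
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.SimpleCatalog
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation

// Case-sensitive conf and a trivial plan, purely for illustration.
val catalog = new SimpleCatalog(SimpleCatalystConf(true))
val plan = OneRowRelation

catalog.registerTable(Seq("t"), plan)        // OK: plain temporary table name
catalog.registerTable(Seq("db.t"), plan)     // OK: one part that happens to contain a dot
catalog.registerTable(Seq("db", "t"), plan)  // throws AnalysisException: database names
                                             // are not allowed for temporary tables
```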
View file

@ -218,7 +218,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
case _ =>
val cmd =
CreateTableUsingAsSelect(
tableIdent.unquotedString,
tableIdent,
source,
temporary = false,
partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),

View file

@ -584,9 +584,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableName,
tableIdent,
userSpecifiedSchema = None,
source,
temporary = false,
@ -594,7 +595,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
allowExisting = false,
managedIfNoPath = false)
executePlan(cmd).toRdd
table(tableName)
table(tableIdent)
}
/**
@ -629,9 +630,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableName,
tableIdent,
userSpecifiedSchema = Some(schema),
source,
temporary = false,
@ -639,7 +641,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
allowExisting = false,
managedIfNoPath = false)
executePlan(cmd).toRdd
table(tableName)
table(tableIdent)
}
/**
@ -724,7 +726,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
table(new SqlParser().parseTableIdentifier(tableName))
}
private def table(tableIdent: TableIdentifier): DataFrame = {
DataFrame(this, catalog.lookupRelation(tableIdent.toSeq))
}

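A small sketch of how the reworked `createExternalTable` and `table` handle qualified names (illustrative; `sqlContext` is assumed to be a `HiveContext`, and `path` points at an existing Parquet dataset):

```scala
// `mydb` and `path` are placeholders for this illustration.
sqlContext.createExternalTable("mydb.t", "parquet", Map("path" -> path))

// Both calls go through SqlParser.parseTableIdentifier, so the database part is
// preserved instead of being folded into the table name.
val restored = sqlContext.table("mydb.t")
```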
View file

@ -395,22 +395,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object DDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) =>
case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, true, opts, false, _) =>
ExecutedCommand(
CreateTempTableUsing(
tableName, userSpecifiedSchema, provider, opts)) :: Nil
tableIdent, userSpecifiedSchema, provider, opts)) :: Nil
case c: CreateTableUsing if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")
case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query)
case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query)
if partitionsCols.nonEmpty =>
sys.error("Cannot create temporary partitioned table.")
case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) =>
case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) =>
val cmd = CreateTempTableUsingAsSelect(
tableName, provider, Array.empty[String], mode, opts, query)
tableIdent, provider, Array.empty[String], mode, opts, query)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")

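For illustration, with a plain `SQLContext` (no Hive support) the branches above only allow temporary, non-partitioned tables; the sketch below assumes a Parquet dataset already exists at `/tmp/data`:

```scala
// Allowed: a temporary data source table.
sqlContext.sql(
  "CREATE TEMPORARY TABLE t USING parquet OPTIONS (path '/tmp/data')")

// Fails at planning time with
// "Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead."
sqlContext.sql(
  "CREATE TABLE t2 USING parquet OPTIONS (path '/tmp/data')")
```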
View file

@ -80,9 +80,9 @@ class DDLParser(parseQuery: String => LogicalPlan)
*/
protected lazy val createTable: Parser[LogicalPlan] = {
// TODO: Support database.table.
(CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~
(CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~
tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ {
case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query =>
case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query =>
if (temp.isDefined && allowExisting.isDefined) {
throw new DDLException(
"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
@ -104,7 +104,7 @@ class DDLParser(parseQuery: String => LogicalPlan)
}
val queryPlan = parseQuery(query.get)
CreateTableUsingAsSelect(tableName,
CreateTableUsingAsSelect(tableIdent,
provider,
temp.isDefined,
Array.empty[String],
@ -114,7 +114,7 @@ class DDLParser(parseQuery: String => LogicalPlan)
} else {
val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
CreateTableUsing(
tableName,
tableIdent,
userSpecifiedSchema,
provider,
temp.isDefined,
@ -125,6 +125,12 @@ class DDLParser(parseQuery: String => LogicalPlan)
}
}
// This is the same as tableIdentifier in SqlParser.
protected lazy val tableIdentifier: Parser[TableIdentifier] =
(ident <~ ".").? ~ ident ^^ {
case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName)
}
protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
/*
@ -132,21 +138,15 @@ class DDLParser(parseQuery: String => LogicalPlan)
* This will display all columns of table `avroTable` includes column_name,column_type,comment
*/
protected lazy val describeTable: Parser[LogicalPlan] =
(DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ {
case e ~ db ~ tbl =>
val tblIdentifier = db match {
case Some(dbName) =>
Seq(dbName, tbl)
case None =>
Seq(tbl)
}
DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined)
(DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ {
case e ~ tableIdent =>
DescribeCommand(UnresolvedRelation(tableIdent.toSeq, None), e.isDefined)
}
protected lazy val refreshTable: Parser[LogicalPlan] =
REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ {
case maybeDatabaseName ~ tableName =>
RefreshTable(TableIdentifier(tableName, maybeDatabaseName))
REFRESH ~> TABLE ~> tableIdentifier ^^ {
case tableIdent =>
RefreshTable(tableIdent)
}
protected lazy val options: Parser[Map[String, String]] =

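The new `tableIdentifier` production splits at most one dot into an optional database part, while a backtick-quoted name stays a single part. An illustration using the equivalent `SqlParser` method (expected values shown as comments):

```scala
import org.apache.spark.sql.catalyst.SqlParser

val parser = new SqlParser()
parser.parseTableIdentifier("db.t")    // TableIdentifier("t", Some("db"))
parser.parseTableIdentifier("`db.t`")  // TableIdentifier("db.t", None): the dot stays in the name
```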
View file

@ -53,7 +53,7 @@ case class DescribeCommand(
* If it is false, an exception will be thrown
*/
case class CreateTableUsing(
tableName: String,
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
provider: String,
temporary: Boolean,
@ -71,8 +71,9 @@ case class CreateTableUsing(
* can analyze the logical plan that will be used to populate the table.
* So, [[PreWriteCheck]] can detect cases that are not allowed.
*/
// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104).
case class CreateTableUsingAsSelect(
tableName: String,
tableIdent: TableIdentifier,
provider: String,
temporary: Boolean,
partitionColumns: Array[String],
@ -80,12 +81,10 @@ case class CreateTableUsingAsSelect(
options: Map[String, String],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
// TODO: Override resolved after we support databaseName.
// override lazy val resolved = databaseName != None && childrenResolved
}
case class CreateTempTableUsing(
tableName: String,
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String]) extends RunnableCommand {
@ -93,14 +92,16 @@ case class CreateTempTableUsing(
def run(sqlContext: SQLContext): Seq[Row] = {
val resolved = ResolvedDataSource(
sqlContext, userSpecifiedSchema, Array.empty[String], provider, options)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
sqlContext.catalog.registerTable(
tableIdent.toSeq,
DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan)
Seq.empty[Row]
}
}
case class CreateTempTableUsingAsSelect(
tableName: String,
tableIdent: TableIdentifier,
provider: String,
partitionColumns: Array[String],
mode: SaveMode,
@ -110,8 +111,9 @@ case class CreateTempTableUsingAsSelect(
override def run(sqlContext: SQLContext): Seq[Row] = {
val df = DataFrame(sqlContext, query)
val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
sqlContext.catalog.registerTable(
tableIdent.toSeq,
DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan)
Seq.empty[Row]
}

View file

@ -140,12 +140,12 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
// OK
}
case CreateTableUsingAsSelect(tableName, _, _, partitionColumns, mode, _, query) =>
case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) =>
// When the SaveMode is Overwrite, we need to check if the table is an input table of
// the query. If so, we will throw an AnalysisException to let users know it is not allowed.
if (mode == SaveMode.Overwrite && catalog.tableExists(Seq(tableName))) {
if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent.toSeq)) {
// Need to remove SubQuery operator.
EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match {
EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) match {
// Only do the check if the table is a data source table
// (the relation is a BaseRelation).
case l @ LogicalRelation(dest: BaseRelation) =>
@ -155,7 +155,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
}
if (srcRelations.contains(dest)) {
failAnalysis(
s"Cannot overwrite table $tableName that is also being read from.")
s"Cannot overwrite table $tableIdent that is also being read from.")
} else {
// OK
}

View file

@ -1644,4 +1644,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(sql("select count(num) from 1one"), Row(10))
}
}
test("specifying database name for a temporary table is not allowed") {
withTempPath { dir =>
val path = dir.getCanonicalPath
val df =
sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str")
df
.write
.format("parquet")
.save(path)
val message = intercept[AnalysisException] {
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE db.t
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
}.getMessage
assert(message.contains("Specifying database name or other qualifiers are not allowed"))
// If backticks are used to quote it, a temporary table name may contain dots.
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE `db.t`
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
checkAnswer(sqlContext.table("`db.t`"), df)
}
}
}

View file

@ -43,7 +43,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.SQLConf.SQLConfEntry._
import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect}
import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand}
@ -189,6 +189,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
// We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options
// into the isolated client loader
val metadataConf = new HiveConf()
val defaultWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir")
logInfo("default warehouse location is " + defaultWarehouseLocation)
// `configure` goes second to override other settings.
val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure
@ -288,12 +292,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
* @since 1.3.0
*/
def refreshTable(tableName: String): Unit = {
val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase)
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
catalog.refreshTable(tableIdent)
}
protected[hive] def invalidateTable(tableName: String): Unit = {
catalog.invalidateTable(catalog.client.currentDatabase, tableName)
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
catalog.invalidateTable(tableIdent)
}
/**
@ -307,7 +312,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
*/
@Experimental
def analyze(tableName: String) {
val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName)))
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq))
relation match {
case relation: MetastoreRelation =>

View file

@ -174,10 +174,13 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
// it is better to invalidate the cache here to avoid confusing warning logs from the
// cache loader (e.g. cannot find data source provider, which is only defined for
// data source tables).
invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table)
invalidateTable(tableIdent)
}
def invalidateTable(databaseName: String, tableName: String): Unit = {
def invalidateTable(tableIdent: TableIdentifier): Unit = {
val databaseName = tableIdent.database.getOrElse(client.currentDatabase)
val tableName = tableIdent.table
cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase)
}
@ -187,6 +190,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
* Creates a data source table (a table created with USING clause) in Hive's metastore.
* Returns true when the table has been created. Otherwise, false.
*/
// TODO: Remove this in SPARK-10104.
def createDataSourceTable(
tableName: String,
userSpecifiedSchema: Option[StructType],
@ -203,7 +207,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
isExternal)
}
private def createDataSourceTable(
def createDataSourceTable(
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
partitionColumns: Array[String],
@ -371,10 +375,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
}
def hiveDefaultTableFilePath(tableName: String): String = {
hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName))
}
def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = {
// Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName)
val database = tableIdent.database.getOrElse(client.currentDatabase)
new Path(
new Path(client.getDatabase(client.currentDatabase).location),
tableName.toLowerCase).toString
new Path(client.getDatabase(database).location),
tableIdent.table.toLowerCase).toString
}
def tableExists(tableIdentifier: Seq[String]): Boolean = {
@ -635,7 +645,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists
CreateTableUsingAsSelect(
desc.name,
TableIdentifier(desc.name),
hive.conf.defaultDataSourceName,
temporary = false,
Array.empty[String],

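To make the location change concrete: `hiveDefaultTableFilePath` now resolves against the identifier's database rather than the current one. The sketch below uses made-up values; `catalog` stands for a `HiveMetastoreCatalog`:

```scala
// Suppose client.getDatabase("mydb").location is "hdfs://nn/user/hive/warehouse/mydb.db".
val path = catalog.hiveDefaultTableFilePath(TableIdentifier("T", Some("mydb")))
// path == "hdfs://nn/user/hive/warehouse/mydb.db/t"   (the table name is lower-cased)
```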
View file

@ -83,14 +83,16 @@ private[hive] trait HiveStrategies {
object HiveDDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case CreateTableUsing(
tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) =>
ExecutedCommand(
CreateMetastoreDataSource(
tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil
case CreateTableUsingAsSelect(tableName, provider, false, partitionCols, mode, opts, query) =>
tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) =>
val cmd =
CreateMetastoreDataSourceAsSelect(tableName, provider, partitionCols, mode, opts, query)
CreateMetastoreDataSource(
tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)
ExecutedCommand(cmd) :: Nil
case CreateTableUsingAsSelect(
tableIdent, provider, false, partitionCols, mode, opts, query) =>
val cmd =
CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query)
ExecutedCommand(cmd) :: Nil
case _ => Nil

View file

@ -17,7 +17,9 @@
package org.apache.spark.sql.hive.execution
import org.apache.hadoop.hive.metastore.MetaStoreUtils
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser}
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@ -120,9 +122,10 @@ case class AddFile(path: String) extends RunnableCommand {
}
}
// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104).
private[hive]
case class CreateMetastoreDataSource(
tableName: String,
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String],
@ -130,9 +133,24 @@ case class CreateMetastoreDataSource(
managedIfNoPath: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
// Since we are saving metadata to metastore, we need to check if metastore supports
// the table name and database name we have for this query. MetaStoreUtils.validateName
// is the method used by Hive to check if a table name or a database name is valid for
// the metastore.
if (!MetaStoreUtils.validateName(tableIdent.table)) {
throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " +
s"metastore. Metastore only accepts table name containing characters, numbers and _.")
}
if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) {
throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " +
s"for metastore. Metastore only accepts database name containing " +
s"characters, numbers and _.")
}
val tableName = tableIdent.unquotedString
val hiveContext = sqlContext.asInstanceOf[HiveContext]
if (hiveContext.catalog.tableExists(tableName :: Nil)) {
if (hiveContext.catalog.tableExists(tableIdent.toSeq)) {
if (allowExisting) {
return Seq.empty[Row]
} else {
@ -144,13 +162,13 @@ case class CreateMetastoreDataSource(
val optionsWithPath =
if (!options.contains("path") && managedIfNoPath) {
isExternal = false
options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName))
options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent))
} else {
options
}
hiveContext.catalog.createDataSourceTable(
tableName,
tableIdent,
userSpecifiedSchema,
Array.empty[String],
provider,
@ -161,9 +179,10 @@ case class CreateMetastoreDataSource(
}
}
// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104).
private[hive]
case class CreateMetastoreDataSourceAsSelect(
tableName: String,
tableIdent: TableIdentifier,
provider: String,
partitionColumns: Array[String],
mode: SaveMode,
@ -171,19 +190,34 @@ case class CreateMetastoreDataSourceAsSelect(
query: LogicalPlan) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
// Since we are saving metadata to metastore, we need to check if metastore supports
// the table name and database name we have for this query. MetaStoreUtils.validateName
// is the method used by Hive to check if a table name or a database name is valid for
// the metastore.
if (!MetaStoreUtils.validateName(tableIdent.table)) {
throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " +
s"metastore. Metastore only accepts table name containing characters, numbers and _.")
}
if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) {
throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " +
s"for metastore. Metastore only accepts database name containing " +
s"characters, numbers and _.")
}
val tableName = tableIdent.unquotedString
val hiveContext = sqlContext.asInstanceOf[HiveContext]
var createMetastoreTable = false
var isExternal = true
val optionsWithPath =
if (!options.contains("path")) {
isExternal = false
options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName))
options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent))
} else {
options
}
var existingSchema = None: Option[StructType]
if (sqlContext.catalog.tableExists(Seq(tableName))) {
if (sqlContext.catalog.tableExists(tableIdent.toSeq)) {
// Check if we need to throw an exception or just return.
mode match {
case SaveMode.ErrorIfExists =>
@ -200,7 +234,7 @@ case class CreateMetastoreDataSourceAsSelect(
val resolved = ResolvedDataSource(
sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath)
val createdRelation = LogicalRelation(resolved.relation)
EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match {
EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent.toSeq)) match {
case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) =>
if (l.relation != createdRelation.relation) {
val errorDescription =
@ -249,7 +283,7 @@ case class CreateMetastoreDataSourceAsSelect(
// the schema of df). It is important since the nullability may be changed by the relation
// provider (for example, see org.apache.spark.sql.parquet.DefaultSource).
hiveContext.catalog.createDataSourceTable(
tableName,
tableIdent,
Some(resolved.relation.schema),
partitionColumns,
provider,
@ -258,7 +292,7 @@ case class CreateMetastoreDataSourceAsSelect(
}
// Refresh the cache of the table in the catalog.
hiveContext.refreshTable(tableName)
hiveContext.catalog.refreshTable(tableIdent)
Seq.empty[Row]
}
}

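For reference, `MetaStoreUtils.validateName` only accepts names built from letters, digits, and underscores, so invalid names are rejected before anything is written to the metastore. A sketch mirroring the new `MultiDatabaseSuite` test (with a placeholder `df`):

```scala
// Backticks get the qualified name through the parser, but the metastore check
// still rejects the ':' characters.
df.write.format("parquet").saveAsTable("`d:b`.`t:a`")
// => AnalysisException: "Table name t:a is not a valid name for metastore. ..."
```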
View file

@ -34,7 +34,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {
// The catalog in HiveContext is a case insensitive one.
catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan)
catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan)
sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
@ -42,7 +41,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
override def afterAll(): Unit = {
catalog.unregisterTable(Seq("ListTablesSuiteTable"))
catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"))
sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
@ -55,7 +53,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
allTables.filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0)
checkAnswer(
allTables.filter("tableName = 'hivelisttablessuitetable'"),
Row("hivelisttablessuitetable", false))
@ -69,9 +66,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
allTables.filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
allTables.filter("tableName = 'indblisttablessuitetable'"),
Row("indblisttablessuitetable", true))
assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
checkAnswer(
allTables.filter("tableName = 'hiveindblisttablessuitetable'"),

View file

@ -19,14 +19,22 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode}
import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode}
class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
override val _sqlContext: SQLContext = TestHive
override val _sqlContext: HiveContext = TestHive
private val sqlContext = _sqlContext
private val df = sqlContext.range(10).coalesce(1)
private def checkTablePath(dbName: String, tableName: String): Unit = {
// val hiveContext = sqlContext.asInstanceOf[HiveContext]
val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName)
val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName
assert(metastoreTable.serdeProperties("path") === expectedPath)
}
test(s"saveAsTable() to non-default database - with USE - Overwrite") {
withTempDatabase { db =>
activateDatabase(db) {
@ -37,6 +45,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df)
checkTablePath(db, "t")
}
}
@ -45,6 +55,58 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df)
checkTablePath(db, "t")
}
}
test(s"createExternalTable() to non-default database - with USE") {
withTempDatabase { db =>
activateDatabase(db) {
withTempPath { dir =>
val path = dir.getCanonicalPath
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
sqlContext.createExternalTable("t", path, "parquet")
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table("t"), df)
sql(
s"""
|CREATE TABLE t1
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
assert(sqlContext.tableNames(db).contains("t1"))
checkAnswer(sqlContext.table("t1"), df)
}
}
}
}
test(s"createExternalTable() to non-default database - without USE") {
withTempDatabase { db =>
withTempPath { dir =>
val path = dir.getCanonicalPath
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
sqlContext.createExternalTable(s"$db.t", path, "parquet")
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df)
sql(
s"""
|CREATE TABLE $db.t1
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
assert(sqlContext.tableNames(db).contains("t1"))
checkAnswer(sqlContext.table(s"$db.t1"), df)
}
}
}
@ -59,6 +121,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkTablePath(db, "t")
}
}
@ -68,6 +132,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
assert(sqlContext.tableNames(db).contains("t"))
checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df))
checkTablePath(db, "t")
}
}
@ -130,7 +196,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
}
}
test("Refreshes a table in a non-default database") {
test("Refreshes a table in a non-default database - with USE") {
import org.apache.spark.sql.functions.lit
withTempDatabase { db =>
@ -151,8 +217,94 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
sql("ALTER TABLE t ADD PARTITION (p=1)")
sql("REFRESH TABLE t")
checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1)))
df.write.parquet(s"$path/p=2")
sql("ALTER TABLE t ADD PARTITION (p=2)")
sqlContext.refreshTable("t")
checkAnswer(
sqlContext.table("t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
}
}
}
}
test("Refreshes a table in a non-default database - without USE") {
import org.apache.spark.sql.functions.lit
withTempDatabase { db =>
withTempPath { dir =>
val path = dir.getCanonicalPath
sql(
s"""CREATE EXTERNAL TABLE $db.t (id BIGINT)
|PARTITIONED BY (p INT)
|STORED AS PARQUET
|LOCATION '$path'
""".stripMargin)
checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame)
df.write.parquet(s"$path/p=1")
sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)")
sql(s"REFRESH TABLE $db.t")
checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1)))
df.write.parquet(s"$path/p=2")
sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)")
sqlContext.refreshTable(s"$db.t")
checkAnswer(
sqlContext.table(s"$db.t"),
df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2))))
}
}
}
test("invalid database name and table names") {
{
val message = intercept[AnalysisException] {
df.write.format("parquet").saveAsTable("`d:b`.`t:a`")
}.getMessage
assert(message.contains("is not a valid name for metastore"))
}
{
val message = intercept[AnalysisException] {
df.write.format("parquet").saveAsTable("`d:b`.`table`")
}.getMessage
assert(message.contains("is not a valid name for metastore"))
}
withTempPath { dir =>
val path = dir.getCanonicalPath
{
val message = intercept[AnalysisException] {
sql(
s"""
|CREATE TABLE `d:b`.`t:a` (a int)
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
}.getMessage
assert(message.contains("is not a valid name for metastore"))
}
{
val message = intercept[AnalysisException] {
sql(
s"""
|CREATE TABLE `d:b`.`table` (a int)
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
}.getMessage
assert(message.contains("is not a valid name for metastore"))
}
}
}
}

View file

@ -1138,4 +1138,39 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
Row(CalendarInterval.fromString(
"interval 4 minutes 59 seconds 889 milliseconds 987 microseconds")))
}
test("specifying database name for a temporary table is not allowed") {
withTempPath { dir =>
val path = dir.getCanonicalPath
val df =
sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str")
df
.write
.format("parquet")
.save(path)
val message = intercept[AnalysisException] {
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE db.t
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
}.getMessage
assert(message.contains("Specifying database name or other qualifiers are not allowed"))
// If backticks are used to quote it, a temporary table name may contain dots.
sqlContext.sql(
s"""
|CREATE TEMPORARY TABLE `db.t`
|USING parquet
|OPTIONS (
| path '$path'
|)
""".stripMargin)
checkAnswer(sqlContext.table("`db.t`"), df)
}
}
}