[SPARK-30104][SQL][FOLLOWUP] V2 catalog named 'global_temp' should always be masked

### What changes were proposed in this pull request?

This is a follow-up to #26741 to address the following:
1. A V2 catalog named `global_temp` should always be masked.
2. #26741 introduced `CatalogAndIdentifier`, which supersedes `CatalogObjectIdentifier`. This PR removes `CatalogObjectIdentifier` and replaces its usages with `CatalogAndIdentifier`.
3. `CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog)` and `CatalogObjectIdentifier(catalog, ident) if isSessionCatalog(catalog)` are replaced with `NonSessionCatalogAndIdentifier` and `SessionCatalogAndIdentifier`, respectively (illustrated in the sketch after this list).
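
As a rough illustration (not the Spark source), the sketch below shows how the old guarded matches collapse into the dedicated extractors. The simplified `CatalogPlugin`, `Identifier`, and hard-coded catalog map are stand-ins for the real v2 catalog API; only the shape of the extractors mirrors the `LookupCatalog` trait.

```scala
// Minimal, self-contained sketch; types and the catalog map are simplified stand-ins.
trait CatalogPlugin { def name: String }
case class Identifier(namespace: Seq[String], name: String)

object ExtractorSketch {
  private val sessionCatalog: CatalogPlugin = new CatalogPlugin { val name = "spark_catalog" }
  private val catalogs: Map[String, CatalogPlugin] =
    Map("prod" -> new CatalogPlugin { val name = "prod" })

  private def currentCatalog: CatalogPlugin = sessionCatalog
  private def isSessionCatalog(c: CatalogPlugin): Boolean = c eq sessionCatalog

  // Resolve the leading name part to a registered catalog, falling back to the current catalog.
  object CatalogAndIdentifier {
    def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = parts match {
      case Seq(single) => Some((currentCatalog, Identifier(Nil, single)))
      case head +: rest =>
        catalogs.get(head) match {
          case Some(catalog) => Some((catalog, Identifier(rest.init, rest.last)))
          case None => Some((currentCatalog, Identifier(parts.init, parts.last)))
        }
      case _ => None // empty name
    }
  }

  // The dedicated extractors fold the isSessionCatalog guard into the pattern itself, so call
  // sites no longer repeat `CatalogObjectIdentifier(catalog, ident) if isSessionCatalog(catalog)`.
  object SessionCatalogAndIdentifier {
    def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = parts match {
      case CatalogAndIdentifier(catalog, ident) if isSessionCatalog(catalog) => Some((catalog, ident))
      case _ => None
    }
  }

  object NonSessionCatalogAndIdentifier {
    def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = parts match {
      case CatalogAndIdentifier(catalog, ident) if !isSessionCatalog(catalog) => Some((catalog, ident))
      case _ => None
    }
  }

  // Call sites now read like DataFrameWriter.insertInto/saveAsTable after this PR.
  def describe(parts: Seq[String]): String = parts match {
    case NonSessionCatalogAndIdentifier(catalog, ident) => s"v2 catalog '${catalog.name}': $ident"
    case SessionCatalogAndIdentifier(_, ident) => s"session catalog: $ident"
    case _ => "unresolved"
  }
}

// ExtractorSketch.describe(Seq("prod", "db", "tbl"))  // v2 catalog 'prod': Identifier(List(db),tbl)
// ExtractorSketch.describe(Seq("db", "tbl"))          // session catalog: Identifier(List(db),tbl)
```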

### Why are the changes needed?

To fix an existing issue with handling a v2 catalog named `global_temp` and to simplify the code base.
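
Concretely, a v2 catalog registered under the global-temp database name is always masked: any multi-part name starting with that database is routed to the session catalog (previously only two-part names were, per the `nameParts.length == 2` check removed below). A hedged spark-shell sketch of the behavior; the catalog class name is a hypothetical placeholder, and the new test at the bottom of this diff uses the test-only `InMemoryTableCatalog` instead:

```scala
// Hypothetical spark-shell session; `com.example.MyTableCatalog` stands in for any
// TableCatalog implementation. `global_temp` is the default value of
// spark.sql.globalTempDatabase.
spark.conf.set("spark.sql.catalog.global_temp", "com.example.MyTableCatalog")

// The multi-part name starts with `global_temp`, so it is routed to the session catalog
// (where global temp views live) rather than to the v2 catalog registered above; the
// statement then fails because `global_temp.ns1.ns2.tbl` is not a valid v1 TableIdentifier.
spark.sql("CREATE TABLE global_temp.ns1.ns2.tbl (id bigint, data string) USING json")
```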

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Added new tests.

Closes #26853 from imback82/lookup_table.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit 3741a36ebf (parent b709091b4f), authored by Terry Kim on 2019-12-12 14:47:20 +08:00 and committed by Wenchen Fan.
5 changed files with 36 additions and 32 deletions

LookupCatalog.scala

@@ -52,25 +52,12 @@ private[sql] trait LookupCatalog extends Logging {
}
}
/**
* Extract catalog and identifier from a multi-part identifier with the current catalog if needed.
*/
object CatalogObjectIdentifier {
def unapply(parts: Seq[String]): Some[(CatalogPlugin, Identifier)] = parts match {
case CatalogAndMultipartIdentifier(maybeCatalog, nameParts) =>
Some((
maybeCatalog.getOrElse(currentCatalog),
Identifier.of(nameParts.init.toArray, nameParts.last)
))
}
}
/**
* Extract session catalog and identifier from a multi-part identifier.
*/
object SessionCatalogAndIdentifier {
def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = parts match {
case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
case CatalogAndIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
Some(catalog, ident)
case _ => None
}
@@ -81,7 +68,7 @@ private[sql] trait LookupCatalog extends Logging {
*/
object NonSessionCatalogAndIdentifier {
def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = parts match {
case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
case CatalogAndIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
Some(catalog, ident)
case _ => None
}
@@ -117,7 +104,7 @@ private[sql] trait LookupCatalog extends Logging {
assert(nameParts.nonEmpty)
if (nameParts.length == 1) {
Some((currentCatalog, Identifier.of(Array(), nameParts.head)))
} else if (nameParts.length == 2 && nameParts.head.equalsIgnoreCase(globalTempDB)) {
} else if (nameParts.head.equalsIgnoreCase(globalTempDB)) {
// Conceptually global temp views are in a special reserved catalog. However, the v2 catalog
// API does not support view yet, and we have to use v1 commands to deal with global temp
// views. To simplify the implementation, we put global temp views in a special namespace
@@ -139,7 +126,7 @@ private[sql] trait LookupCatalog extends Logging {
/**
* Extract legacy table identifier from a multi-part identifier.
*
* For legacy support only. Please use [[CatalogObjectIdentifier]] instead on DSv2 code paths.
* For legacy support only. Please use [[CatalogAndIdentifier]] instead on DSv2 code paths.
*/
object AsTableIdentifier {
def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {

LookupCatalogSuite.scala

@@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
private case class DummyCatalogPlugin(override val name: String) extends CatalogPlugin {
@@ -36,7 +37,9 @@ private case class DummyCatalogPlugin(override val name: String) extends Catalog
class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
import CatalystSqlParser._
private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
private val globalTempDB = SQLConf.get.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
private val catalogs =
Seq("prod", "test", globalTempDB).map(x => x -> DummyCatalogPlugin(x)).toMap
private val sessionCatalog = FakeV2SessionCatalog
override val catalogManager: CatalogManager = {
@@ -46,13 +49,16 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.currentCatalog).thenReturn(sessionCatalog)
when(manager.v2SessionCatalog).thenReturn(sessionCatalog)
manager
}
test("catalog object identifier") {
test("catalog and identifier") {
Seq(
("tbl", sessionCatalog, Seq.empty, "tbl"),
("db.tbl", sessionCatalog, Seq("db"), "tbl"),
(s"$globalTempDB.tbl", sessionCatalog, Seq(globalTempDB), "tbl"),
(s"$globalTempDB.ns1.ns2.tbl", sessionCatalog, Seq(globalTempDB, "ns1", "ns2"), "tbl"),
("prod.func", catalogs("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
@@ -64,7 +70,7 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
case CatalogObjectIdentifier(catalog, ident) =>
case CatalogAndIdentifier(catalog, ident) =>
catalog shouldEqual expectedCatalog
ident shouldEqual Identifier.of(namespace.toArray, name)
}
@@ -156,7 +162,7 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
case CatalogObjectIdentifier(catalog, ident) =>
case CatalogAndIdentifier(catalog, ident) =>
catalog shouldEqual expectedCatalog
ident shouldEqual Identifier.of(namespace.toArray, name)
}

DataFrameWriter.scala

@@ -339,7 +339,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._
@@ -357,11 +357,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val canUseV2 = lookupV2Provider().isDefined
session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
case NonSessionCatalogAndIdentifier(catalog, ident) =>
insertInto(catalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
case SessionCatalogAndIdentifier(catalog, ident)
if canUseV2 && ident.namespace().length <= 1 =>
insertInto(catalog, ident)
case AsTableIdentifier(tableIdentifier) =>
@@ -479,19 +479,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._
val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
case NonSessionCatalogAndIdentifier(catalog, ident) =>
saveAsTable(catalog.asTableCatalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
case SessionCatalogAndIdentifier(catalog, ident)
if canUseV2 && ident.namespace().length <= 1 =>
saveAsTable(catalog.asTableCatalog, ident)
case AsTableIdentifier(tableIdentifier) =>

DataFrameWriterV2.scala

@@ -41,7 +41,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._
import df.sparkSession.sessionState.analyzer.CatalogObjectIdentifier
import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
private val df: DataFrame = ds.toDF()
@@ -52,7 +52,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
private val (catalog, identifier) = {
val CatalogObjectIdentifier(catalog, identifier) = tableName
val CatalogAndIdentifier(catalog, identifier) = tableName
(catalog.asTableCatalog, identifier)
}

DataSourceV2SQLSuite.scala

@@ -1845,6 +1845,18 @@ class DataSourceV2SQLSuite
}
}
test("SPARK-30104: v2 catalog named global_temp will be masked") {
val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
val e = intercept[AnalysisException] {
// Since the following multi-part name starts with `globalTempDB`, it is resolved to
// the session catalog, not the `global_temp` v2 catalog.
sql(s"CREATE TABLE $globalTempDB.ns1.ns2.tbl (id bigint, data string) USING json")
}
assert(e.message.contains("global_temp.ns1.ns2.tbl is not a valid TableIdentifier"))
}
test("table name same as catalog can be used") {
withTable("testcat.testcat") {
sql(s"CREATE TABLE testcat.testcat (id bigint, data string) USING foo")