From 3741a36ebf326b56956289e06922d178982e4879 Mon Sep 17 00:00:00 2001
From: Terry Kim
Date: Thu, 12 Dec 2019 14:47:20 +0800
Subject: [PATCH] [SPARK-30104][SQL][FOLLOWUP] V2 catalog named 'global_temp' should always be masked

### What changes were proposed in this pull request?

This is a follow-up to #26741 to address the following:

1. V2 catalog named `global_temp` should always be masked.
2. #26741 introduces `CatalogAndIdentifier`, which supersedes `CatalogObjectIdentifier`. This PR removes `CatalogObjectIdentifier` and its usages and replaces them with `CatalogAndIdentifier`.
3. `CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog)` and `CatalogObjectIdentifier(catalog, ident) if isSessionCatalog(catalog)` are replaced with `NonSessionCatalogAndIdentifier` and `SessionCatalogAndIdentifier` respectively.

### Why are the changes needed?

To fix an existing issue with handling a v2 catalog named `global_temp` and to simplify the code base.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Added new tests.

Closes #26853 from imback82/lookup_table.

Authored-by: Terry Kim
Signed-off-by: Wenchen Fan
---
 .../sql/connector/catalog/LookupCatalog.scala | 21 ++++-----------------
 .../catalog/LookupCatalogSuite.scala          | 14 ++++++++++----
 .../apache/spark/sql/DataFrameWriter.scala    | 17 ++++++++---------
 .../apache/spark/sql/DataFrameWriterV2.scala  |  4 ++--
 .../sql/connector/DataSourceV2SQLSuite.scala  | 12 ++++++++++++
 5 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
index 4d3aff2274..59e7805547 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
@@ -52,25 +52,12 @@ private[sql] trait LookupCatalog extends Logging {
     }
   }
 
-  /**
-   * Extract catalog and identifier from a multi-part identifier with the current catalog if needed.
-   */
-  object CatalogObjectIdentifier {
-    def unapply(parts: Seq[String]): Some[(CatalogPlugin, Identifier)] = parts match {
-      case CatalogAndMultipartIdentifier(maybeCatalog, nameParts) =>
-        Some((
-          maybeCatalog.getOrElse(currentCatalog),
-          Identifier.of(nameParts.init.toArray, nameParts.last)
-        ))
-    }
-  }
-
   /**
    * Extract session catalog and identifier from a multi-part identifier.
   */
  object SessionCatalogAndIdentifier {
    def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = parts match {
-      case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
+      case CatalogAndIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
         Some(catalog, ident)
       case _ => None
     }
@@ -81,7 +68,7 @@ private[sql] trait LookupCatalog extends Logging {
    */
   object NonSessionCatalogAndIdentifier {
     def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = parts match {
-      case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
+      case CatalogAndIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
         Some(catalog, ident)
       case _ => None
     }
@@ -117,7 +104,7 @@ private[sql] trait LookupCatalog extends Logging {
       assert(nameParts.nonEmpty)
       if (nameParts.length == 1) {
         Some((currentCatalog, Identifier.of(Array(), nameParts.head)))
-      } else if (nameParts.length == 2 && nameParts.head.equalsIgnoreCase(globalTempDB)) {
+      } else if (nameParts.head.equalsIgnoreCase(globalTempDB)) {
         // Conceptually global temp views are in a special reserved catalog. However, the v2 catalog
         // API does not support view yet, and we have to use v1 commands to deal with global temp
         // views. To simplify the implementation, we put global temp views in a special namespace
@@ -139,7 +126,7 @@ private[sql] trait LookupCatalog extends Logging {
   /**
    * Extract legacy table identifier from a multi-part identifier.
    *
-   * For legacy support only. Please use [[CatalogObjectIdentifier]] instead on DSv2 code paths.
+   * For legacy support only. Please use [[CatalogAndIdentifier]] instead on DSv2 code paths.
    */
   object AsTableIdentifier {
     def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
index 513f7e0348..a576e66236 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 private case class DummyCatalogPlugin(override val name: String) extends CatalogPlugin {
@@ -36,7 +37,9 @@ private case class DummyCatalogPlugin(override val name: String) extends Catalog
 class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
   import CatalystSqlParser._
 
-  private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
+  private val globalTempDB = SQLConf.get.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
+  private val catalogs =
+    Seq("prod", "test", globalTempDB).map(x => x -> DummyCatalogPlugin(x)).toMap
   private val sessionCatalog = FakeV2SessionCatalog
 
   override val catalogManager: CatalogManager = {
@@ -46,13 +49,16 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
         catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
       })
     when(manager.currentCatalog).thenReturn(sessionCatalog)
+    when(manager.v2SessionCatalog).thenReturn(sessionCatalog)
     manager
   }
 
-  test("catalog object identifier") {
+  test("catalog and identifier") {
     Seq(
       ("tbl", sessionCatalog, Seq.empty, "tbl"),
       ("db.tbl", sessionCatalog, Seq("db"), "tbl"),
+      (s"$globalTempDB.tbl", sessionCatalog, Seq(globalTempDB), "tbl"),
+      (s"$globalTempDB.ns1.ns2.tbl", sessionCatalog, Seq(globalTempDB, "ns1", "ns2"), "tbl"),
       ("prod.func", catalogs("prod"), Seq.empty, "func"),
       ("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
       ("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
@@ -64,7 +70,7 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
         Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
       case (sql, expectedCatalog, namespace, name) =>
         inside(parseMultipartIdentifier(sql)) {
-          case CatalogObjectIdentifier(catalog, ident) =>
+          case CatalogAndIdentifier(catalog, ident) =>
             catalog shouldEqual expectedCatalog
             ident shouldEqual Identifier.of(namespace.toArray, name)
         }
@@ -156,7 +162,7 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
         Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
       case (sql, expectedCatalog, namespace, name) =>
         inside(parseMultipartIdentifier(sql)) {
-          case CatalogObjectIdentifier(catalog, ident) =>
+          case CatalogAndIdentifier(catalog, ident) =>
             catalog shouldEqual expectedCatalog
             ident shouldEqual Identifier.of(namespace.toArray, name)
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 92515a0210..2b124ae260 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -339,7 +339,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * @since 1.4.0
    */
   def insertInto(tableName: String): Unit = {
-    import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
+    import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier}
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
     import org.apache.spark.sql.connector.catalog.CatalogV2Util._
 
@@ -357,11 +357,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     val canUseV2 = lookupV2Provider().isDefined
 
     session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
-      case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
+      case NonSessionCatalogAndIdentifier(catalog, ident) =>
         insertInto(catalog, ident)
 
-      case CatalogObjectIdentifier(catalog, ident)
-          if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
+      case SessionCatalogAndIdentifier(catalog, ident)
+          if canUseV2 && ident.namespace().length <= 1 =>
         insertInto(catalog, ident)
 
       case AsTableIdentifier(tableIdentifier) =>
@@ -479,19 +479,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * @since 1.4.0
    */
   def saveAsTable(tableName: String): Unit = {
-    import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
+    import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier}
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-    import org.apache.spark.sql.connector.catalog.CatalogV2Util._
 
     val session = df.sparkSession
     val canUseV2 = lookupV2Provider().isDefined
 
     session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
-      case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
+      case NonSessionCatalogAndIdentifier(catalog, ident) =>
         saveAsTable(catalog.asTableCatalog, ident)
 
-      case CatalogObjectIdentifier(catalog, ident)
-          if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
+      case SessionCatalogAndIdentifier(catalog, ident)
+          if canUseV2 && ident.namespace().length <= 1 =>
         saveAsTable(catalog.asTableCatalog, ident)
 
       case AsTableIdentifier(tableIdentifier) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index cf534ab6b9..f0758809bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -41,7 +41,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
 
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
   import org.apache.spark.sql.connector.catalog.CatalogV2Util._
-  import df.sparkSession.sessionState.analyzer.CatalogObjectIdentifier
+  import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
 
   private val df: DataFrame = ds.toDF()
 
@@ -52,7 +52,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
 
   private val (catalog, identifier) = {
-    val CatalogObjectIdentifier(catalog, identifier) = tableName
+    val CatalogAndIdentifier(catalog, identifier) = tableName
     (catalog.asTableCatalog, identifier)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 50ec0d775b..4c0a472edb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1845,6 +1845,18 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("SPARK-30104: v2 catalog named global_temp will be masked") {
+    val globalTempDB = spark.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
+    spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
+
+    val e = intercept[AnalysisException] {
+      // Since the following multi-part name starts with `globalTempDB`, it is resolved to
+      // the session catalog, not the `gloabl_temp` v2 catalog.
+      sql(s"CREATE TABLE $globalTempDB.ns1.ns2.tbl (id bigint, data string) USING json")
+    }
+    assert(e.message.contains("global_temp.ns1.ns2.tbl is not a valid TableIdentifier"))
+  }
+
   test("table name same as catalog can be used") {
     withTable("testcat.testcat") {
       sql(s"CREATE TABLE testcat.testcat (id bigint, data string) USING foo")