From 8fc267ab3322e46db81e725a5cb1adb5a71b2b4d Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Wed, 20 Apr 2016 12:58:48 -0700
Subject: [PATCH] [SPARK-14720][SPARK-13643] Move Hive-specific methods into
 HiveSessionState and Create a SparkSession class

## What changes were proposed in this pull request?

This PR has two main changes.

1. Move Hive-specific methods from HiveContext to HiveSessionState, which helps with the eventual removal of HiveContext.
2. Create a SparkSession class, which will later become the entry point for Spark SQL users.

## How was this patch tested?

Existing tests.

This PR is trying to fix the test failures of https://github.com/apache/spark/pull/12485.

Author: Andrew Or
Author: Yin Huai

Closes #12522 from yhuai/spark-session.
---
 .../spark/internal/config/package.scala | 7 +
 .../org/apache/spark/sql/SQLContext.scala | 52 ++--
 .../org/apache/spark/sql/SparkSession.scala | 100 +++++++
 .../spark/sql/internal/SessionState.scala | 54 +++-
 .../SparkExecuteStatementOperation.scala | 2 +-
 .../hive/thriftserver/SparkSQLDriver.scala | 7 +-
 .../sql/hive/thriftserver/SparkSQLEnv.scala | 6 +-
 .../thriftserver/SparkSQLSessionManager.scala | 2 +-
 .../server/SparkSQLOperationManager.scala | 2 +-
 .../execution/HiveCompatibilitySuite.scala | 6 +-
 .../HiveWindowFunctionQuerySuite.scala | 31 +-
 .../apache/spark/sql/hive/HiveContext.scala | 223 ++------------
 .../spark/sql/hive/HiveMetastoreCatalog.scala | 30 +-
 .../spark/sql/hive/HiveQueryExecution.scala | 66 +++++
 .../spark/sql/hive/HiveSessionCatalog.scala | 11 +-
 .../spark/sql/hive/HiveSessionState.scala | 182 ++++++++++--
 .../spark/sql/hive/HiveStrategies.scala | 11 +-
 .../apache/spark/sql/hive/TableReader.scala | 17 +-
 .../sql/hive/client/HiveClientImpl.scala | 2 +
 .../hive/execution/CreateTableAsSelect.scala | 9 +-
 .../hive/execution/CreateViewAsSelect.scala | 11 +-
 .../hive/execution/HiveNativeCommand.scala | 8 +-
 .../sql/hive/execution/HiveSqlParser.scala | 19 +-
 .../sql/hive/execution/HiveTableScan.scala | 6 +-
 .../hive/execution/InsertIntoHiveTable.scala | 23 +-
 .../hive/execution/ScriptTransformation.scala | 9 +-
 .../spark/sql/hive/execution/commands.scala | 38 ++-
 .../apache/spark/sql/hive/test/TestHive.scala | 275 +++++++++++-------
 .../spark/sql/hive/ErrorPositionSuite.scala | 3 +-
 .../spark/sql/hive/HiveContextSuite.scala | 7 +-
 .../sql/hive/HiveMetastoreCatalogSuite.scala | 6 +-
 .../sql/hive/MetastoreDataSourcesSuite.scala | 33 +--
 .../spark/sql/hive/MultiDatabaseSuite.scala | 9 +-
 .../spark/sql/hive/StatisticsSuite.scala | 4 +-
 .../execution/BigDataBenchmarkSuite.scala | 4 +-
 .../hive/execution/HiveComparisonTest.scala | 25 +-
 .../sql/hive/execution/HiveQuerySuite.scala | 6 +-
 .../sql/hive/execution/HiveSerDeSuite.scala | 4 +-
 .../sql/hive/execution/PruningSuite.scala | 18 +-
 .../sql/hive/execution/SQLQuerySuite.scala | 4 +-
 .../execution/ScriptTransformationSuite.scala | 8 +-
 .../apache/spark/sql/hive/parquetSuites.scala | 6 +-
 .../spark/sql/sources/BucketedReadSuite.scala | 2 +-
 43 files changed, 799 insertions(+), 549 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
 create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQueryExecution.scala

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 94b50ee065..2c1e0b71e3 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@
-89,4 +89,11 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + // Note: This is a SQL config but needs to be in core because the REPL depends on it + private[spark] val CATALOG_IMPLEMENTATION = ConfigBuilder("spark.sql.catalogImplementation") + .internal() + .stringConf + .checkValues(Set("hive", "in-memory")) + .createWithDefault("in-memory") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 781d699819..f3f84144ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -63,14 +63,18 @@ import org.apache.spark.util.Utils * @since 1.0.0 */ class SQLContext private[sql]( - @transient protected[sql] val sharedState: SharedState, + @transient private val sparkSession: SparkSession, val isRootContext: Boolean) extends Logging with Serializable { self => + private[sql] def this(sparkSession: SparkSession) = { + this(sparkSession, true) + } + def this(sc: SparkContext) = { - this(new SharedState(sc), true) + this(new SparkSession(sc)) } def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) @@ -97,12 +101,15 @@ class SQLContext private[sql]( } } - def sparkContext: SparkContext = sharedState.sparkContext - + protected[sql] def sessionState: SessionState = sparkSession.sessionState + protected[sql] def sharedState: SharedState = sparkSession.sharedState + protected[sql] def conf: SQLConf = sessionState.conf protected[sql] def cacheManager: CacheManager = sharedState.cacheManager protected[sql] def listener: SQLListener = sharedState.listener protected[sql] def externalCatalog: ExternalCatalog = sharedState.externalCatalog + def sparkContext: SparkContext = sharedState.sparkContext + /** * Returns a [[SQLContext]] as new session, with separated SQL configurations, temporary * tables, registered functions, but sharing the same [[SparkContext]], cached data and @@ -110,14 +117,9 @@ class SQLContext private[sql]( * * @since 1.6.0 */ - def newSession(): SQLContext = new SQLContext(sharedState, isRootContext = false) - - /** - * Per-session state, e.g. configuration, functions, temporary tables etc. - */ - @transient - protected[sql] lazy val sessionState: SessionState = new SessionState(self) - protected[spark] def conf: SQLConf = sessionState.conf + def newSession(): SQLContext = { + new SQLContext(sparkSession.newSession(), isRootContext = false) + } /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -132,10 +134,14 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def setConf(props: Properties): Unit = conf.setConf(props) + def setConf(props: Properties): Unit = sessionState.setConf(props) - /** Set the given Spark SQL configuration property. */ - private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = conf.setConf(entry, value) + /** + * Set the given Spark SQL configuration property. + */ + private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { + sessionState.setConf(entry, value) + } /** * Set the given Spark SQL configuration property. @@ -143,7 +149,7 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def setConf(key: String, value: String): Unit = conf.setConfString(key, value) + def setConf(key: String, value: String): Unit = sessionState.setConf(key, value) /** * Return the value of Spark SQL configuration property for the given key. 
@@ -186,23 +192,19 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - // Extract `spark.sql.*` entries and put it in our SQLConf. - // Subclasses may additionally set these entries in other confs. - SQLContext.getSQLProperties(sparkContext.getConf).asScala.foreach { case (k, v) => - setConf(k, v) - } - protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = new QueryExecution(this, plan) + protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = { + sessionState.executePlan(plan) + } /** * Add a jar to SQLContext */ protected[sql] def addJar(path: String): Unit = { - sparkContext.addJar(path) + sessionState.addJar(path) } /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. */ @@ -768,7 +770,7 @@ class SQLContext private[sql]( * as Spark can parse all supported Hive DDLs itself. */ private[sql] def runNativeSql(sqlText: String): Seq[Row] = { - throw new UnsupportedOperationException + sessionState.runNativeSql(sqlText).map { r => Row(r) } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala new file mode 100644 index 0000000000..17ba299825 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.internal.{SessionState, SharedState} +import org.apache.spark.util.Utils + + +/** + * The entry point to Spark execution. + */ +class SparkSession private( + sparkContext: SparkContext, + existingSharedState: Option[SharedState]) { self => + + def this(sc: SparkContext) { + this(sc, None) + } + + /** + * Start a new session where configurations, temp tables, temp functions etc. are isolated. + */ + def newSession(): SparkSession = { + // Note: materialize the shared state here to ensure the parent and child sessions are + // initialized with the same shared state. 
+ new SparkSession(sparkContext, Some(sharedState)) + } + + @transient + protected[sql] lazy val sharedState: SharedState = { + existingSharedState.getOrElse( + SparkSession.reflect[SharedState, SparkContext]( + SparkSession.sharedStateClassName(sparkContext.conf), + sparkContext)) + } + + @transient + protected[sql] lazy val sessionState: SessionState = { + SparkSession.reflect[SessionState, SQLContext]( + SparkSession.sessionStateClassName(sparkContext.conf), + new SQLContext(self, isRootContext = false)) + } + +} + + +private object SparkSession { + + private def sharedStateClassName(conf: SparkConf): String = { + conf.get(CATALOG_IMPLEMENTATION) match { + case "hive" => "org.apache.spark.sql.hive.HiveSharedState" + case "in-memory" => classOf[SharedState].getCanonicalName + } + } + + private def sessionStateClassName(conf: SparkConf): String = { + conf.get(CATALOG_IMPLEMENTATION) match { + case "hive" => "org.apache.spark.sql.hive.HiveSessionState" + case "in-memory" => classOf[SessionState].getCanonicalName + } + } + + /** + * Helper method to create an instance of [[T]] using a single-arg constructor that + * accepts an [[Arg]]. + */ + private def reflect[T, Arg <: AnyRef]( + className: String, + ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = { + try { + val clazz = Utils.classForName(className) + val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass) + ctor.newInstance(ctorArg).asInstanceOf[T] + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Error while instantiating '$className':", e) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index d404a7c0ae..42915d5887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -17,15 +17,22 @@ package org.apache.spark.sql.internal +import java.util.Properties + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, PreInsertCastAndRename, ResolveDataSource} import org.apache.spark.sql.util.ExecutionListenerManager + /** * A class that holds all session-specific state in a given [[SQLContext]]. */ @@ -37,7 +44,10 @@ private[sql] class SessionState(ctx: SQLContext) { /** * SQL-specific key-value configurations. */ - lazy val conf = new SQLConf + lazy val conf: SQLConf = new SQLConf + + // Automatically extract `spark.sql.*` entries and put it in our SQLConf + setConf(SQLContext.getSQLProperties(ctx.sparkContext.getConf)) lazy val experimentalMethods = new ExperimentalMethods @@ -101,5 +111,45 @@ private[sql] class SessionState(ctx: SQLContext) { * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. 
*/ lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) -} + + // ------------------------------------------------------ + // Helper methods, partially leftover from pre-2.0 days + // ------------------------------------------------------ + + def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(ctx, plan) + + def refreshTable(tableName: String): Unit = { + catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) + } + + def invalidateTable(tableName: String): Unit = { + catalog.invalidateTable(sqlParser.parseTableIdentifier(tableName)) + } + + final def setConf(properties: Properties): Unit = { + properties.asScala.foreach { case (k, v) => setConf(k, v) } + } + + final def setConf[T](entry: ConfigEntry[T], value: T): Unit = { + conf.setConf(entry, value) + setConf(entry.key, entry.stringConverter(value)) + } + + def setConf(key: String, value: String): Unit = { + conf.setConfString(key, value) + } + + def addJar(path: String): Unit = { + ctx.sparkContext.addJar(path) + } + + def analyze(tableName: String): Unit = { + throw new UnsupportedOperationException + } + + def runNativeSql(sql: String): Seq[String] = { + throw new UnsupportedOperationException + } + +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 673a293ce2..d89c3b4ab2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -195,7 +195,7 @@ private[hive] class SparkExecuteStatementOperation( setState(OperationState.RUNNING) // Always use the latest class loader provided by executionHive's state. 
val executionHiveClassLoader = - hiveContext.executionHive.state.getConf.getClassLoader + hiveContext.sessionState.executionHive.state.getConf.getClassLoader Thread.currentThread().setContextClassLoader(executionHiveClassLoader) HiveThriftServer2.listener.onStatementStart( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index b8bc8ea44d..7e8eada5ad 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, HiveQueryExecution} private[hive] class SparkSQLDriver( val context: HiveContext = SparkSQLEnv.hiveContext) @@ -41,7 +41,7 @@ private[hive] class SparkSQLDriver( override def init(): Unit = { } - private def getResultSetSchema(query: context.QueryExecution): Schema = { + private def getResultSetSchema(query: HiveQueryExecution): Schema = { val analyzed = query.analyzed logDebug(s"Result Schema: ${analyzed.output}") if (analyzed.output.isEmpty) { @@ -59,7 +59,8 @@ private[hive] class SparkSQLDriver( // TODO unify the error code try { context.sparkContext.setJobDescription(command) - val execution = context.executePlan(context.sql(command).logicalPlan) + val execution = + context.executePlan(context.sql(command).logicalPlan).asInstanceOf[HiveQueryExecution] hiveResponse = execution.stringResult() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index ae1d737b58..2679ac1854 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -58,9 +58,9 @@ private[hive] object SparkSQLEnv extends Logging { sparkContext.addSparkListener(new StatsReportListener()) hiveContext = new HiveContext(sparkContext) - hiveContext.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) - hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) - hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + hiveContext.sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) + hiveContext.sessionState.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) + hiveContext.sessionState.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index de4e9c62b5..f492b5656c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -71,7 +71,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - val ctx = if (hiveContext.hiveThriftServerSingleSession) { + val ctx = if (hiveContext.sessionState.hiveThriftServerSingleSession) { hiveContext } else { hiveContext.newSession() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 0c468a408b..da410c68c8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -47,7 +47,7 @@ private[thriftserver] class SparkSQLOperationManager() confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { val hiveContext = sessionToContexts(parentSession.getSessionHandle) - val runInBackground = async && hiveContext.hiveThriftServerAsync + val runInBackground = async && hiveContext.sessionState.hiveThriftServerAsync val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(hiveContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 989e68aebe..49fd198730 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -39,7 +39,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - private val originalConvertMetastoreOrc = TestHive.convertMetastoreOrc + private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc def testCases: Seq[(String, File)] = { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) @@ -47,7 +47,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def beforeAll() { super.beforeAll() - TestHive.cacheTables = true + TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -66,7 +66,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def afterAll() { try { - TestHive.cacheTables = false + TestHive.setCacheTables(false) TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index d0b4cbe401..de592f8d93 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -38,7 +38,8 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte private val testTempDir = Utils.createTempDir() override def beforeAll() { - TestHive.cacheTables = true + super.beforeAll() + TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -100,11 +101,14 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte } override def afterAll() { - TestHive.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - TestHive.reset() - super.afterAll() + try { + TestHive.setCacheTables(false) + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.reset() + } finally { + super.afterAll() + } } ///////////////////////////////////////////////////////////////////////////// @@ -773,7 +777,8 @@ class HiveWindowFunctionQueryFileSuite private val testTempDir = Utils.createTempDir() override def beforeAll() { - TestHive.cacheTables = true + super.beforeAll() + TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -790,10 +795,14 @@ class HiveWindowFunctionQueryFileSuite } override def afterAll() { - TestHive.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - TestHive.reset() + try { + TestHive.setCacheTables(false) + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.reset() + } finally { + super.afterAll() + } } override def blackList: Seq[String] = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b26a9ab699..b2ce3e0df2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -22,7 +22,6 @@ import java.net.{URL, URLClassLoader} import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.concurrent.TimeUnit -import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap @@ -32,26 +31,18 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import 
org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.command.{ExecutedCommand, SetCommand} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{AnalyzeTable, DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.internal.{SharedState, SQLConf} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -61,225 +52,45 @@ import org.apache.spark.util.Utils * @since 1.0.0 */ class HiveContext private[hive]( - @transient protected[hive] val hiveSharedState: HiveSharedState, - override val isRootContext: Boolean) - extends SQLContext(hiveSharedState, isRootContext) with Logging { + @transient private val sparkSession: SparkSession, + isRootContext: Boolean) + extends SQLContext(sparkSession, isRootContext) with Logging { self => def this(sc: SparkContext) = { - this(new HiveSharedState(sc), true) + this(new SparkSession(HiveContext.withHiveExternalCatalog(sc)), true) } def this(sc: JavaSparkContext) = this(sc.sc) - import org.apache.spark.sql.hive.HiveContext._ - - logDebug("create HiveContext") - /** * Returns a new HiveContext as new session, which will have separated SQLConf, UDF/UDAF, * temporary tables and SessionState, but sharing the same CacheManager, IsolatedClientLoader * and Hive client (both of execution and metadata) with existing HiveContext. */ override def newSession(): HiveContext = { - new HiveContext(hiveSharedState, isRootContext = false) + new HiveContext(sparkSession.newSession(), isRootContext = false) } - @transient - protected[sql] override lazy val sessionState = new HiveSessionState(self) - - protected[hive] def hiveCatalog: HiveExternalCatalog = hiveSharedState.externalCatalog - protected[hive] def executionHive: HiveClientImpl = sessionState.executionHive - protected[hive] def metadataHive: HiveClient = sessionState.metadataHive - - /** - * When true, enables an experimental feature where metastore tables that use the parquet SerDe - * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive - * SerDe. - */ - protected[sql] def convertMetastoreParquet: Boolean = getConf(CONVERT_METASTORE_PARQUET) - - /** - * When true, also tries to merge possibly different but compatible Parquet schemas in different - * Parquet data files. - * - * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. - */ - protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = - getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - - /** - * When true, enables an experimental feature where metastore tables that use the Orc SerDe - * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive - * SerDe. - */ - protected[sql] def convertMetastoreOrc: Boolean = getConf(CONVERT_METASTORE_ORC) - - /** - * When true, a table created by a Hive CTAS statement (no USING clause) will be - * converted to a data source table, using the data source set by spark.sql.sources.default. 
- * The table in CTAS statement will be converted when it meets any of the following conditions: - * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or - * a Storage Hanlder (STORED BY), and the value of hive.default.fileformat in hive-site.xml - * is either TextFile or SequenceFile. - * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe - * is specified (no ROW FORMAT SERDE clause). - * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format - * and no SerDe is specified (no ROW FORMAT SERDE clause). - */ - protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) - - /* - * hive thrift server use background spark sql thread pool to execute sql queries - */ - protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) - - protected[hive] def hiveThriftServerSingleSession: Boolean = - sparkContext.conf.getBoolean("spark.sql.hive.thriftServer.singleSession", defaultValue = false) - - @transient - protected[sql] lazy val substitutor = new VariableSubstitution() - - /** - * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. - * - allow SQL11 keywords to be used as identifiers - */ - private[sql] def defaultOverrides() = { - setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") + protected[sql] override def sessionState: HiveSessionState = { + sparkSession.sessionState.asInstanceOf[HiveSessionState] } - defaultOverrides() - - protected[sql] override def parseSql(sql: String): LogicalPlan = { - executionHive.withHiveState { - super.parseSql(substitutor.substitute(sessionState.hiveconf, sql)) - } + protected[sql] override def sharedState: HiveSharedState = { + sparkSession.sharedState.asInstanceOf[HiveSharedState] } - override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - - /** - * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, - * Spark SQL or the external data source library it uses might cache certain metadata about a - * table, such as the location of blocks. When those change outside of Spark SQL, users should - * call this function to invalidate the cache. - * - * @since 1.3.0 - */ - def refreshTable(tableName: String): Unit = { - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - sessionState.catalog.refreshTable(tableIdent) - } - - protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - sessionState.catalog.invalidateTable(tableIdent) - } - - /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations. - * - * Right now, it only supports Hive tables and it only updates the size of a Hive table - * in the Hive metastore. - * - * @since 1.2.0 - */ - def analyze(tableName: String) { - AnalyzeTable(tableName).run(self) - } - - override def setConf(key: String, value: String): Unit = { - super.setConf(key, value) - executionHive.runSqlHive(s"SET $key=$value") - metadataHive.runSqlHive(s"SET $key=$value") - // If users put any Spark SQL setting in the spark conf (e.g. spark-defaults.conf), - // this setConf will be called in the constructor of the SQLContext. 
- // Also, calling hiveconf will create a default session containing a HiveConf, which - // will interfer with the creation of executionHive (which is a lazy val). So, - // we put hiveconf.set at the end of this method. - sessionState.hiveconf.set(key, value) - } - - override private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { - setConf(entry.key, entry.stringConverter(value)) - } - - private def functionOrMacroDDLPattern(command: String) = Pattern.compile( - ".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL).matcher(command) - - protected[hive] def runSqlHive(sql: String): Seq[String] = { - val command = sql.trim.toLowerCase - if (functionOrMacroDDLPattern(command).matches()) { - executionHive.runSqlHive(sql) - } else if (command.startsWith("set")) { - metadataHive.runSqlHive(sql) - executionHive.runSqlHive(sql) - } else { - metadataHive.runSqlHive(sql) - } - } - - /** - * Executes a SQL query without parsing it, but instead passing it directly to Hive. - * This is currently only used for DDLs and will be removed as soon as Spark can parse - * all supported Hive DDLs itself. - */ - protected[sql] override def runNativeSql(sqlText: String): Seq[Row] = { - runSqlHive(sqlText).map { s => Row(s) } - } - - /** Extends QueryExecution with hive specific features. */ - protected[sql] class QueryExecution(logicalPlan: LogicalPlan) - extends org.apache.spark.sql.execution.QueryExecution(this, logicalPlan) { - - /** - * Returns the result as a hive compatible sequence of strings. For native commands, the - * execution is simply passed back to Hive. - */ - def stringResult(): Seq[String] = executedPlan match { - case ExecutedCommand(desc: DescribeHiveTableCommand) => - // If it is a describe command for a Hive table, we want to have the output format - // be similar with Hive. - desc.run(self).map { - case Row(name: String, dataType: String, comment) => - Seq(name, dataType, - Option(comment.asInstanceOf[String]).getOrElse("")) - .map(s => String.format(s"%-20s", s)) - .mkString("\t") - } - case command: ExecutedCommand => - command.executeCollect().map(_.getString(0)) - - case other => - val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq - // We need the types so we can output struct field names - val types = analyzed.output.map(_.dataType) - // Reformat to match hive tab delimited output. - result.map(_.zip(types).map(HiveContext.toHiveString)).map(_.mkString("\t")).toSeq - } - - override def simpleString: String = - logical match { - case _: HiveNativeCommand => "" - case _: SetCommand => "" - case _ => super.simpleString - } - } - - protected[sql] override def addJar(path: String): Unit = { - // Add jar to Hive and classloader - executionHive.addJar(path) - metadataHive.addJar(path) - Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader) - super.addJar(path) - } } private[hive] object HiveContext extends Logging { + + def withHiveExternalCatalog(sc: SparkContext): SparkContext = { + sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive") + sc + } + /** The version of hive used internally by Spark SQL. 
*/ val hiveExecutionVersion: String = "1.2.1" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 753950ff84..33a926e4d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -115,17 +115,16 @@ private[hive] object HiveSerDe { * This is still used for things like creating data source tables, but in the future will be * cleaned up to integrate more nicely with [[HiveExternalCatalog]]. */ -private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) - extends Logging { - - val conf = hive.conf +private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { + private val conf = hive.conf + private val sessionState = hive.sessionState.asInstanceOf[HiveSessionState] + private val client = hive.sharedState.asInstanceOf[HiveSharedState].metadataHive + private val hiveconf = sessionState.hiveconf /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) - private def getCurrentDatabase: String = { - hive.sessionState.catalog.getCurrentDatabase - } + private def getCurrentDatabase: String = hive.sessionState.catalog.getCurrentDatabase def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { QualifiedTableName( @@ -298,7 +297,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte CatalogTableType.MANAGED_TABLE } - val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.sessionState.hiveconf) + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hiveconf) val dataSource = DataSource( hive, @@ -600,14 +599,14 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte object ParquetConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreParquet(relation: MetastoreRelation): Boolean = { relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") && - hive.convertMetastoreParquet + sessionState.convertMetastoreParquet } private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = { val defaultSource = new ParquetDefaultSource() val fileFormatClass = classOf[ParquetDefaultSource] - val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging val options = Map( ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( @@ -652,7 +651,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte object OrcConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreOrc(relation: MetastoreRelation): Boolean = { relation.tableDesc.getSerdeClassName.toLowerCase.contains("orc") && - hive.convertMetastoreOrc + sessionState.convertMetastoreOrc } private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = { @@ -727,7 +726,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val desc = table.copy(schema = schema) - if (hive.convertCTAS && table.storage.serde.isEmpty) { + if (sessionState.convertCTAS && table.storage.serde.isEmpty) { // Do the conversion when spark.sql.hive.convertCTAS is true and the query // does not specify any storage format (file format and storage handler). 
if (table.identifier.database.isDefined) { @@ -815,14 +814,13 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte * the information from the metastore. */ class MetaStoreFileCatalog( - hive: HiveContext, + ctx: SQLContext, paths: Seq[Path], partitionSpecFromHive: PartitionSpec) - extends HDFSFileCatalog(hive, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { - + extends HDFSFileCatalog(ctx, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { override def getStatus(path: Path): Array[FileStatus] = { - val fs = path.getFileSystem(hive.sparkContext.hadoopConfiguration) + val fs = path.getFileSystem(ctx.sparkContext.hadoopConfiguration) fs.listStatus(path) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQueryExecution.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQueryExecution.scala new file mode 100644 index 0000000000..1c1bfb610c --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQueryExecution.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.command.{ExecutedCommand, SetCommand} +import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} + + +/** + * A [[QueryExecution]] with hive specific features. + */ +protected[hive] class HiveQueryExecution(ctx: SQLContext, logicalPlan: LogicalPlan) + extends QueryExecution(ctx, logicalPlan) { + + /** + * Returns the result as a hive compatible sequence of strings. For native commands, the + * execution is simply passed back to Hive. + */ + def stringResult(): Seq[String] = executedPlan match { + case ExecutedCommand(desc: DescribeHiveTableCommand) => + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. + desc.run(ctx).map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, + Option(comment.asInstanceOf[String]).getOrElse("")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") + } + case command: ExecutedCommand => + command.executeCollect().map(_.getString(0)) + + case other => + val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq + // We need the types so we can output struct field names + val types = analyzed.output.map(_.dataType) + // Reformat to match hive tab delimited output. 
+ result.map(_.zip(types).map(HiveContext.toHiveString)).map(_.mkString("\t")).toSeq + } + + override def simpleString: String = + logical match { + case _: HiveNativeCommand => "" + case _: SetCommand => "" + case _ => super.simpleString + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index f91393fc76..4f9513389c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -33,7 +34,6 @@ import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCat import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client.HiveClient @@ -45,10 +45,11 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, client: HiveClient, - context: HiveContext, + context: SQLContext, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, - conf: SQLConf) + conf: SQLConf, + hiveconf: HiveConf) extends SessionCatalog(externalCatalog, functionResourceLoader, functionRegistry, conf) { override def setCurrentDatabase(db: String): Unit = { @@ -75,7 +76,7 @@ private[sql] class HiveSessionCatalog( // ---------------------------------------------------------------- override def getDefaultDBPath(db: String): String = { - val defaultPath = context.sessionState.hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE) + val defaultPath = hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE) new Path(new Path(defaultPath), db + ".db").toString } @@ -83,7 +84,7 @@ private[sql] class HiveSessionCatalog( // essentially a cache for metastore tables. However, it relies on a lot of session-specific // things so it would be a lot of work to split its functionality between HiveSessionCatalog // and HiveCatalog. We should still do it at some point... 
- private val metastoreCatalog = new HiveMetastoreCatalog(client, context) + private val metastoreCatalog = new HiveMetastoreCatalog(context) val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 2b848524f3..09297c27dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -17,62 +17,80 @@ package org.apache.spark.sql.hive +import java.util.regex.Pattern + import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.client.{HiveClient, HiveClientImpl} -import org.apache.spark.sql.hive.execution.HiveSqlParser +import org.apache.spark.sql.hive.execution.{AnalyzeTable, HiveSqlParser} import org.apache.spark.sql.internal.{SessionState, SQLConf} /** * A class that holds all session-specific state in a given [[HiveContext]]. */ -private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) { +private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) { + + self => + + private lazy val sharedState: HiveSharedState = ctx.sharedState.asInstanceOf[HiveSharedState] + + /** + * A Hive client used for execution. + */ + lazy val executionHive: HiveClientImpl = sharedState.executionHive.newSession() + + /** + * A Hive client used for interacting with the metastore. + */ + lazy val metadataHive: HiveClient = sharedState.metadataHive.newSession() + + /** + * A Hive helper class for substituting variables in a SQL statement. + */ + lazy val substitutor = new VariableSubstitution + + override lazy val conf: SQLConf = new SQLConf { + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + } + /** * SQLConf and HiveConf contracts: * - * 1. create a new o.a.h.hive.ql.session.SessionState for each [[HiveContext]] + * 1. create a new o.a.h.hive.ql.session.SessionState for each HiveContext * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be * set in the SQLConf *as well as* in the HiveConf. */ lazy val hiveconf: HiveConf = { - val c = ctx.executionHive.conf - ctx.setConf(c.getAllProperties) + val c = executionHive.conf + conf.setConf(c.getAllProperties) c } - /** - * A Hive client used for execution. - */ - val executionHive: HiveClientImpl = ctx.hiveSharedState.executionHive.newSession() - - /** - * A Hive client used for interacting with the metastore. - */ - val metadataHive: HiveClient = ctx.hiveSharedState.metadataHive.newSession() - - override lazy val conf: SQLConf = new SQLConf { - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - } + setDefaultOverrideConfs() /** * Internal catalog for managing table and database states. 
*/ override lazy val catalog = { new HiveSessionCatalog( - ctx.hiveCatalog, - ctx.metadataHive, + sharedState.externalCatalog, + metadataHive, ctx, ctx.functionResourceLoader, functionRegistry, - conf) + conf, + hiveconf) } /** @@ -96,7 +114,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) /** * Parser for HiveQl query texts. */ - override lazy val sqlParser: ParserInterface = new HiveSqlParser(hiveconf) + override lazy val sqlParser: ParserInterface = new HiveSqlParser(substitutor, hiveconf) /** * Planner that takes into account Hive-specific strategies. @@ -104,13 +122,14 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) override def planner: SparkPlanner = { new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) with HiveStrategies { - override val hiveContext = ctx + override val context: SQLContext = ctx + override val hiveconf: HiveConf = self.hiveconf override def strategies: Seq[Strategy] = { experimentalMethods.extraStrategies ++ Seq( FileSourceStrategy, DataSourceStrategy, - HiveCommandStrategy(ctx), + HiveCommandStrategy, HiveDDLStrategy, DDLStrategy, SpecialLimits, @@ -130,4 +149,119 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) } } + + // ------------------------------------------------------ + // Helper methods, partially leftover from pre-2.0 days + // ------------------------------------------------------ + + override def executePlan(plan: LogicalPlan): HiveQueryExecution = { + new HiveQueryExecution(ctx, plan) + } + + /** + * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. + * - allow SQL11 keywords to be used as identifiers + */ + def setDefaultOverrideConfs(): Unit = { + setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") + } + + override def setConf(key: String, value: String): Unit = { + super.setConf(key, value) + executionHive.runSqlHive(s"SET $key=$value") + metadataHive.runSqlHive(s"SET $key=$value") + hiveconf.set(key, value) + } + + override def addJar(path: String): Unit = { + super.addJar(path) + executionHive.addJar(path) + metadataHive.addJar(path) + Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader) + } + + /** + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. + */ + override def analyze(tableName: String): Unit = { + AnalyzeTable(tableName).run(ctx) + } + + /** + * Execute a SQL statement by passing the query text directly to Hive. + */ + override def runNativeSql(sql: String): Seq[String] = { + val command = sql.trim.toLowerCase + val functionOrMacroDDLPattern = Pattern.compile( + ".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL) + if (functionOrMacroDDLPattern.matcher(command).matches()) { + executionHive.runSqlHive(sql) + } else if (command.startsWith("set")) { + metadataHive.runSqlHive(sql) + executionHive.runSqlHive(sql) + } else { + metadataHive.runSqlHive(sql) + } + } + + /** + * When true, enables an experimental feature where metastore tables that use the parquet SerDe + * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive + * SerDe. 
+ */ + def convertMetastoreParquet: Boolean = { + conf.getConf(HiveContext.CONVERT_METASTORE_PARQUET) + } + + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. + */ + def convertMetastoreParquetWithSchemaMerging: Boolean = { + conf.getConf(HiveContext.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) + } + + /** + * When true, enables an experimental feature where metastore tables that use the Orc SerDe + * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive + * SerDe. + */ + def convertMetastoreOrc: Boolean = { + conf.getConf(HiveContext.CONVERT_METASTORE_ORC) + } + + /** + * When true, a table created by a Hive CTAS statement (no USING clause) will be + * converted to a data source table, using the data source set by spark.sql.sources.default. + * The table in CTAS statement will be converted when it meets any of the following conditions: + * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or + * a Storage Hanlder (STORED BY), and the value of hive.default.fileformat in hive-site.xml + * is either TextFile or SequenceFile. + * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe + * is specified (no ROW FORMAT SERDE clause). + * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format + * and no SerDe is specified (no ROW FORMAT SERDE clause). + */ + def convertCTAS: Boolean = { + conf.getConf(HiveContext.CONVERT_CTAS) + } + + /** + * When true, Hive Thrift server will execute SQL queries asynchronously using a thread pool." + */ + def hiveThriftServerAsync: Boolean = { + conf.getConf(HiveContext.HIVE_THRIFT_SERVER_ASYNC) + } + + def hiveThriftServerSingleSession: Boolean = { + ctx.sparkContext.conf.getBoolean( + "spark.sql.hive.thriftServer.singleSession", defaultValue = false) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 010361a32e..bbdcc8c6c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.hive.conf.HiveConf + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -31,12 +33,13 @@ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. 
self: SparkPlanner => - val hiveContext: HiveContext + val context: SQLContext + val hiveconf: HiveConf object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) => - ScriptTransformation(input, script, output, planLater(child), schema)(hiveContext) :: Nil + ScriptTransformation(input, script, output, planLater(child), schema)(hiveconf) :: Nil case _ => Nil } } @@ -74,7 +77,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates)(context, hiveconf)) :: Nil case _ => Nil } @@ -103,7 +106,7 @@ private[hive] trait HiveStrategies { } } - case class HiveCommandStrategy(context: HiveContext) extends Strategy { + case object HiveCommandStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case describe: DescribeCommand => ExecutedCommand( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 98a427380d..6a20d7c25b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -37,6 +37,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -61,8 +62,8 @@ private[hive] class HadoopTableReader( @transient private val attributes: Seq[Attribute], @transient private val relation: MetastoreRelation, - @transient private val sc: HiveContext, - hiveExtraConf: HiveConf) + @transient private val sc: SQLContext, + hiveconf: HiveConf) extends TableReader with Logging { // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". @@ -72,14 +73,12 @@ class HadoopTableReader( private val _minSplitsPerRDD = if (sc.sparkContext.isLocal) { 0 // will splitted based on block by default. 
} else { - math.max( - sc.sessionState.hiveconf.getInt("mapred.map.tasks", 1), - sc.sparkContext.defaultMinPartitions) + math.max(hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions) } - SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveExtraConf) + SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveconf) private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) + sc.sparkContext.broadcast(new SerializableConfiguration(hiveconf)) override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( @@ -164,7 +163,7 @@ class HadoopTableReader( case (partition, partDeserializer) => def updateExistPathSetByPathPattern(pathPatternStr: String) { val pathPattern = new Path(pathPatternStr) - val fs = pathPattern.getFileSystem(sc.sessionState.hiveconf) + val fs = pathPattern.getFileSystem(hiveconf) val matches = fs.globStatus(pathPattern) matches.foreach(fileStatus => existPathSet += fileStatus.getPath.toString) } @@ -261,7 +260,7 @@ class HadoopTableReader( private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { filterOpt match { case Some(filter) => - val fs = path.getFileSystem(sc.sessionState.hiveconf) + val fs = path.getFileSystem(hiveconf) val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) filteredFiles.mkString(",") case None => path.toString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 2a1fff92b5..69f7dbf6ce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -151,6 +151,8 @@ private[hive] class HiveClientImpl( } /** Returns the configuration for the current session. */ + // TODO: We should not use it because HiveSessionState has a hiveconf + // for the current Session. 
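One detail of the TableReader change above worth noting: the reader now takes the HiveConf it needs explicitly and ships it to executors wrapped in SerializableConfiguration, because Hadoop's Configuration (and therefore HiveConf) is not java.io.Serializable. A rough sketch of that wrapping, assuming code inside Spark where the private[spark] SerializableConfiguration utility is visible; broadcastHiveConf is a made-up helper name, not part of the patch.

    import org.apache.hadoop.hive.conf.HiveConf
    import org.apache.spark.SparkContext
    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.util.SerializableConfiguration

    // Wrap the non-serializable HiveConf so it can be broadcast to executors,
    // mirroring what HadoopTableReader does with its hiveconf argument.
    def broadcastHiveConf(
        sc: SparkContext,
        hiveconf: HiveConf): Broadcast[SerializableConfiguration] = {
      sc.broadcast(new SerializableConfiguration(hiveconf))
    }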
def conf: HiveConf = SessionState.get().getConf override def getConf(key: String, defaultValue: String): String = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 29f7dc2997..ceb7f3b890 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -43,7 +43,6 @@ case class CreateTableAsSelect( override def children: Seq[LogicalPlan] = Seq(query) override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -69,24 +68,24 @@ case class CreateTableAsSelect( withFormat } - hiveContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) + sqlContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) // Get the Metastore Relation - hiveContext.sessionState.catalog.lookupRelation(tableIdentifier) match { + sqlContext.sessionState.catalog.lookupRelation(tableIdentifier) match { case r: MetastoreRelation => r } } // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - if (hiveContext.sessionState.catalog.tableExists(tableIdentifier)) { + if (sqlContext.sessionState.catalog.tableExists(tableIdentifier)) { if (allowExisting) { // table already exists, will do nothing, to keep consistent with Hive } else { throw new AnalysisException(s"$tableIdentifier already exists.") } } else { - hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd + sqlContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 33cd8b4480..1e234d8508 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.command.RunnableCommand -import org.apache.spark.sql.hive.{ HiveContext, HiveMetastoreTypes, SQLBuilder} +import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveSessionState, SQLBuilder} /** * Create Hive view on non-hive-compatible tables by specifying schema ourselves instead of @@ -47,16 +46,16 @@ private[hive] case class CreateViewAsSelect( private val tableIdentifier = tableDesc.identifier override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] - 
hiveContext.sessionState.catalog.tableExists(tableIdentifier) match { + sessionState.catalog.tableExists(tableIdentifier) match { case true if allowExisting => // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view // already exists. case true if orReplace => // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - hiveContext.metadataHive.alertView(prepareTable(sqlContext)) + sessionState.metadataHive.alertView(prepareTable(sqlContext)) case true => // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already @@ -66,7 +65,7 @@ private[hive] case class CreateViewAsSelect( "CREATE OR REPLACE VIEW AS") case false => - hiveContext.metadataHive.createView(prepareTable(sqlContext)) + sessionState.metadataHive.createView(prepareTable(sqlContext)) } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 9bb971992d..8c1f4a8dc5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.command.RunnableCommand -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.HiveSessionState import org.apache.spark.sql.types.StringType private[hive] @@ -29,6 +29,8 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand { override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext): Seq[Row] = - sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.asInstanceOf[HiveSessionState].runNativeSql(sql).map(Row(_)) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index d5d3ee43d7..4ff02cdbd0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -21,8 +21,7 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.parse.EximUtil -import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.parse.{EximUtil, VariableSubstitution} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -39,11 +38,19 @@ import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper /** * Concrete parser for HiveQl statements. 
*/ -class HiveSqlParser(hiveConf: HiveConf) extends AbstractSqlParser { - val astBuilder = new HiveSqlAstBuilder(hiveConf) +class HiveSqlParser( + substitutor: VariableSubstitution, + hiveconf: HiveConf) + extends AbstractSqlParser { - override protected def nativeCommand(sqlText: String): LogicalPlan = { - HiveNativeCommand(sqlText) + val astBuilder = new HiveSqlAstBuilder(hiveconf) + + protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + super.parse(substitutor.substitute(hiveconf, command))(toResult) + } + + protected override def nativeCommand(sqlText: String): LogicalPlan = { + HiveNativeCommand(substitutor.substitute(hiveconf, sqlText)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 3c46b836dc..9a834660f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.Object import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ @@ -47,7 +48,8 @@ case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Seq[Expression])( - @transient val context: HiveContext) + @transient val context: SQLContext, + @transient val hiveconf: HiveConf) extends LeafNode { require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, @@ -75,7 +77,7 @@ case class HiveTableScan( // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient - private[this] val hiveExtraConf = new HiveConf(context.sessionState.hiveconf) + private[this] val hiveExtraConf = new HiveConf(hiveconf) // append columns ids and names before broadcast addColumnMetadataToConf(hiveExtraConf) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index ed538630d2..e614daadf3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -43,9 +43,10 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifNotExists: Boolean) extends UnaryNode { - @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] - @transient private lazy val hiveContext = new Context(sc.sessionState.hiveconf) - @transient private lazy val client = sc.metadataHive + @transient private val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] + @transient private val client = sessionState.metadataHive + @transient private val hiveconf = sessionState.hiveconf + @transient private lazy val hiveContext = new Context(hiveconf) def output: Seq[Attribute] = Seq.empty @@ -67,7 +68,7 @@ case class InsertIntoHiveTable( SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, writerContainer.writeToFile _) + 
sqlContext.sparkContext.runJob(rdd, writerContainer.writeToFile _) writerContainer.commitJob() } @@ -86,17 +87,17 @@ case class InsertIntoHiveTable( val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpPath(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val isCompressed = sc.sessionState.hiveconf.getBoolean( + val isCompressed = hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) if (isCompressed) { // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", // and "mapred.output.compression.type" have no impact on ORC because it uses table properties // to store compression information. - sc.sessionState.hiveconf.set("mapred.output.compress", "true") + hiveconf.set("mapred.output.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(sc.sessionState.hiveconf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(sc.sessionState.hiveconf.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(hiveconf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(hiveconf.get("mapred.output.compression.type")) } val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -113,12 +114,12 @@ case class InsertIntoHiveTable( // Validate partition spec if there exist any dynamic partitions if (numDynamicPartitions > 0) { // Report error if dynamic partitioning is not enabled - if (!sc.sessionState.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { + if (!hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) } // Report error if dynamic partition strict mode is on but no static partition is found - if (numStaticPartitions == 0 && sc.sessionState.hiveconf.getVar( + if (numStaticPartitions == 0 && hiveconf.getVar( HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) } @@ -130,7 +131,7 @@ case class InsertIntoHiveTable( } } - val jobConf = new JobConf(sc.sessionState.hiveconf) + val jobConf = new JobConf(hiveconf) val jobConfSer = new SerializableJobConf(jobConf) // When speculation is on and output committer class name contains "Direct", we should warn diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index ea48b0e5c2..2f7cec354d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -26,6 +26,7 @@ import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe @@ -39,7 +40,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.hive.HiveInspectors import 
org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} @@ -57,14 +58,14 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) + ioschema: HiveScriptIOSchema)(@transient private val hiveconf: HiveConf) extends UnaryNode { - override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil + override protected def otherCopyArgs: Seq[HiveConf] = hiveconf :: Nil override def producedAttributes: AttributeSet = outputSet -- inputSet - private val serializedHiveConf = new SerializableConfiguration(sc.sessionState.hiveconf) + private val serializedHiveConf = new SerializableConfiguration(hiveconf) protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 7a2b60dde5..b5ee9a6295 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveSessionState, MetastoreRelation} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -45,8 +45,7 @@ private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val sessionState = sqlContext.sessionState - val hiveContext = sqlContext.asInstanceOf[HiveContext] + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) @@ -60,7 +59,7 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { // Can we use fs.getContentSummary in future? // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use // countFileSize to count the table size. - val stagingDir = hiveContext.metadataHive.getConf( + val stagingDir = sessionState.metadataHive.getConf( HiveConf.ConfVars.STAGINGDIR.varname, HiveConf.ConfVars.STAGINGDIR.defaultStrVal) @@ -106,7 +105,7 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { .map(_.toLong) .getOrElse(0L) val newTotalSize = - getFileSizeForTable(hiveContext.sessionState.hiveconf, relation.hiveQlTable) + getFileSizeForTable(sessionState.hiveconf, relation.hiveQlTable) // Update the Hive metastore if the total size of the table is different than the size // recorded in the Hive metastore. // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). 
@@ -144,9 +143,8 @@ private[hive] case class AddFile(path: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - hiveContext.runSqlHive(s"ADD FILE $path") - hiveContext.sparkContext.addFile(path) + sqlContext.sessionState.runNativeSql(s"ADD FILE $path") + sqlContext.sparkContext.addFile(path) Seq.empty[Row] } } @@ -176,9 +174,9 @@ case class CreateMetastoreDataSource( } val tableName = tableIdent.unquotedString - val hiveContext = sqlContext.asInstanceOf[HiveContext] + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] - if (hiveContext.sessionState.catalog.tableExists(tableIdent)) { + if (sessionState.catalog.tableExists(tableIdent)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -190,8 +188,7 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } @@ -204,7 +201,7 @@ case class CreateMetastoreDataSource( bucketSpec = None, options = optionsWithPath).resolveRelation() - hiveContext.sessionState.catalog.createDataSourceTable( + sessionState.catalog.createDataSourceTable( tableIdent, userSpecifiedSchema, Array.empty[String], @@ -243,14 +240,13 @@ case class CreateMetastoreDataSourceAsSelect( } val tableName = tableIdent.unquotedString - val hiveContext = sqlContext.asInstanceOf[HiveContext] + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } @@ -281,14 +277,14 @@ case class CreateMetastoreDataSourceAsSelect( // inserting into (i.e. using the same compression). EliminateSubqueryAliases( - sqlContext.sessionState.catalog.lookupRelation(tableIdent)) match { + sessionState.catalog.lookupRelation(tableIdent)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => existingSchema = Some(l.schema) case o => throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } case SaveMode.Overwrite => - hiveContext.sql(s"DROP TABLE IF EXISTS $tableName") + sqlContext.sql(s"DROP TABLE IF EXISTS $tableName") // Need to create the table again. createMetastoreTable = true } @@ -297,7 +293,7 @@ case class CreateMetastoreDataSourceAsSelect( createMetastoreTable = true } - val data = Dataset.ofRows(hiveContext, query) + val data = Dataset.ofRows(sqlContext, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. case Some(s) => data.selectExpr(s.fieldNames: _*) @@ -318,7 +314,7 @@ case class CreateMetastoreDataSourceAsSelect( // We will use the schema of resolved.relation as the schema of the table (instead of // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). 
- hiveContext.sessionState.catalog.createDataSourceTable( + sessionState.catalog.createDataSourceTable( tableIdent, Some(result.schema), partitionColumns, @@ -329,7 +325,7 @@ case class CreateMetastoreDataSourceAsSelect( } // Refresh the cache of the table in the catalog. - hiveContext.sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 2767528395..e629099086 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -32,16 +32,16 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.command.CacheTableCommand -import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.client.{HiveClient, HiveClientImpl} +import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -71,42 +71,80 @@ object TestHive * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. 
*/ -class TestHiveContext private[hive]( - testHiveSharedState: TestHiveSharedState, - val warehousePath: File, - val scratchDirPath: File, - metastoreTemporaryConf: Map[String, String], - isRootContext: Boolean) - extends HiveContext(testHiveSharedState, isRootContext) { self => +class TestHiveContext(@transient val sparkSession: TestHiveSparkSession, isRootContext: Boolean) + extends HiveContext(sparkSession, isRootContext) { - private def this( - sc: SparkContext, - warehousePath: File, - scratchDirPath: File, - metastoreTemporaryConf: Map[String, String]) { - this( - new TestHiveSharedState(sc, warehousePath, scratchDirPath, metastoreTemporaryConf), - warehousePath, - scratchDirPath, - metastoreTemporaryConf, - true) + def this(sc: SparkContext) { + this(new TestHiveSparkSession(HiveContext.withHiveExternalCatalog(sc)), true) } + override def newSession(): TestHiveContext = { + new TestHiveContext(sparkSession.newSession(), false) + } + + override def sharedState: TestHiveSharedState = sparkSession.sharedState + + override def sessionState: TestHiveSessionState = sparkSession.sessionState + + def setCacheTables(c: Boolean): Unit = { + sparkSession.setCacheTables(c) + } + + def getHiveFile(path: String): File = { + sparkSession.getHiveFile(path) + } + + def loadTestTable(name: String): Unit = { + sparkSession.loadTestTable(name) + } + + def reset(): Unit = { + sparkSession.reset() + } + +} + + +private[hive] class TestHiveSparkSession( + sc: SparkContext, + val warehousePath: File, + scratchDirPath: File, + metastoreTemporaryConf: Map[String, String], + existingSharedState: Option[TestHiveSharedState]) + extends SparkSession(sc) with Logging { self => + def this(sc: SparkContext) { this( sc, Utils.createTempDir(namePrefix = "warehouse"), TestHiveContext.makeScratchDir(), - HiveContext.newTemporaryConfiguration(useInMemoryDerby = false)) + HiveContext.newTemporaryConfiguration(useInMemoryDerby = false), + None) } - override def newSession(): HiveContext = { - new TestHiveContext( - testHiveSharedState, - warehousePath, - scratchDirPath, - metastoreTemporaryConf, - isRootContext = false) + assume(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive") + + // TODO: Let's remove TestHiveSharedState and TestHiveSessionState. Otherwise, + // we are not really testing the reflection logic based on the setting of + // CATALOG_IMPLEMENTATION. + @transient + override lazy val sharedState: TestHiveSharedState = { + existingSharedState.getOrElse( + new TestHiveSharedState(sc, warehousePath, scratchDirPath, metastoreTemporaryConf)) + } + + @transient + override lazy val sessionState: TestHiveSessionState = new TestHiveSessionState(self) + + override def newSession(): TestHiveSparkSession = { + new TestHiveSparkSession( + sc, warehousePath, scratchDirPath, metastoreTemporaryConf, Some(sharedState)) + } + + private var cacheTables: Boolean = false + + def setCacheTables(c: Boolean): Unit = { + cacheTables = c } // By clearing the port we force Spark to pick a new one. This allows us to rerun tests @@ -118,9 +156,10 @@ class TestHiveContext private[hive]( // A snapshot of the entries in the starting SQLConf // We save this because tests can mutate this singleton object if they want + // This snapshot is saved when we create this TestHiveSparkSession. 
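The constructor plumbing above is what lets newSession() share state selectively: the optional existingSharedState is passed along, so a derived session reuses the metastore client, external catalog and cache manager while still getting its own TestHiveSessionState. A small sketch of that contract, assuming a SparkContext sc already configured with the hive catalog implementation and code in the org.apache.spark.sql.hive.test package (TestHiveSparkSession is private[hive]); this is illustrative only, not part of the patch.

    // Sketch of the sharing behaviour encoded by the constructor above.
    val session1 = new TestHiveSparkSession(sc)  // lazily builds a fresh TestHiveSharedState
    val session2 = session1.newSession()         // constructed with Some(session1.sharedState)

    // Shared between the two sessions: external catalog, cache manager, metastore client.
    assert(session1.sharedState eq session2.sharedState)
    // Per-session: SQLConf, temporary tables, function registry.
    assert(session1.sessionState ne session2.sessionState)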
val initialSQLConf: SQLConf = { val snapshot = new SQLConf - conf.getAllConfs.foreach { case (k, v) => snapshot.setConfString(k, v) } + sessionState.conf.getAllConfs.foreach { case (k, v) => snapshot.setConfString(k, v) } snapshot } @@ -131,42 +170,10 @@ class TestHiveContext private[hive]( /** The location of the compiled hive distribution */ lazy val hiveHome = envVarToFile("HIVE_HOME") + /** The location of the hive source code. */ lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") - // Override so we can intercept relative paths and rewrite them to point at hive. - override def runSqlHive(sql: String): Seq[String] = - super.runSqlHive(rewritePaths(substitutor.substitute(sessionState.hiveconf, sql))) - - override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - - @transient - protected[sql] override lazy val sessionState = new HiveSessionState(this) { - override lazy val conf: SQLConf = { - new SQLConf { - clear() - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - override def clear(): Unit = { - super.clear() - TestHiveContext.overrideConfs.map { - case (key, value) => setConfString(key, value) - } - } - } - } - - override lazy val functionRegistry = { - // We use TestHiveFunctionRegistry at here to track functions that have been explicitly - // unregistered (through TestHiveFunctionRegistry.unregisterFunction method). - val fr = new TestHiveFunctionRegistry - org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach { - case (name, (info, builder)) => fr.registerFunction(name, info, builder) - } - fr - } - } - /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking * to ensure it exists @@ -179,7 +186,7 @@ class TestHiveContext private[hive]( * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the * hive test cases assume the system is set up. */ - private def rewritePaths(cmd: String): String = + private[hive] def rewritePaths(cmd: String): String = if (cmd.toUpperCase contains "LOAD DATA") { val testDataLocation = hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) @@ -210,36 +217,11 @@ class TestHiveContext private[hive]( val describedTable = "DESCRIBE (\\w+)".r - /** - * Override QueryExecution with special debug workflow. - */ - class QueryExecution(logicalPlan: LogicalPlan) - extends super.QueryExecution(logicalPlan) { - def this(sql: String) = this(parseSql(sql)) - override lazy val analyzed = { - val describedTables = logical match { - case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil - case CacheTableCommand(tbl, _, _) => tbl :: Nil - case _ => Nil - } - - // Make sure any test tables referenced are loaded. - val referencedTables = - describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } - val referencedTestTables = referencedTables.filter(testTables.contains) - logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") - referencedTestTables.foreach(loadTestTable) - // Proceed with analysis. 
- sessionState.analyzer.execute(logical) - } - } - case class TestTable(name: String, commands: (() => Unit)*) protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { - () => new QueryExecution(sql).stringResult(): Unit + () => new TestHiveQueryExecution(sql).stringResult(): Unit } } @@ -266,19 +248,20 @@ class TestHiveContext private[hive]( "CREATE TABLE src1 (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { - runSqlHive( + sessionState.runNativeSql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { - runSqlHive( + sessionState.runNativeSql( s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } }), TestTable("srcpart1", () => { - runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + sessionState.runNativeSql( + "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { - runSqlHive( + sessionState.runNativeSql( s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) @@ -289,7 +272,7 @@ class TestHiveContext private[hive]( import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol - runSqlHive( + sessionState.runNativeSql( s""" |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' @@ -302,7 +285,7 @@ class TestHiveContext private[hive]( |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' """.stripMargin) - runSqlHive( + sessionState.runNativeSql( s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") }), TestTable("serdeins", @@ -415,7 +398,6 @@ class TestHiveContext private[hive]( private val loadedTables = new collection.mutable.HashSet[String] - var cacheTables: Boolean = false def loadTestTable(name: String) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infinite mutually recursive table loading. @@ -426,7 +408,7 @@ class TestHiveContext private[hive]( createCmds.foreach(_()) if (cacheTables) { - cacheTable(name) + new SQLContext(self).cacheTable(name) } } } @@ -451,11 +433,12 @@ class TestHiveContext private[hive]( } } - cacheManager.clearCache() + sharedState.cacheManager.clearCache() loadedTables.clear() sessionState.catalog.clearTempTables() sessionState.catalog.invalidateCache() - metadataHive.reset() + + sessionState.metadataHive.reset() FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } @@ -464,21 +447,21 @@ class TestHiveContext private[hive]( sessionState.hiveconf.set("fs.default.name", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break // other sql exec here. - executionHive.runSqlHive("RESET") - metadataHive.runSqlHive("RESET") + sessionState.executionHive.runSqlHive("RESET") + sessionState.metadataHive.runSqlHive("RESET") // For some reason, RESET does not reset the following variables... 
// https://issues.apache.org/jira/browse/HIVE-9004 - runSqlHive("set hive.table.parameters.default=") - runSqlHive("set datanucleus.cache.collections=true") - runSqlHive("set datanucleus.cache.collections.lazy=true") + sessionState.runNativeSql("set hive.table.parameters.default=") + sessionState.runNativeSql("set datanucleus.cache.collections=true") + sessionState.runNativeSql("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. - runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + sessionState.runNativeSql("set hive.metastore.partition.name.whitelist.pattern=.*") // In case a test changed any of these values, restore all the original ones here. TestHiveContext.hiveClientConfigurations( sessionState.hiveconf, warehousePath, scratchDirPath, metastoreTemporaryConf) - .foreach { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } - defaultOverrides() + .foreach { case (k, v) => sessionState.metadataHive.runSqlHive(s"SET $k=$v") } + sessionState.setDefaultOverrideConfs() sessionState.catalog.setCurrentDatabase("default") } catch { @@ -489,6 +472,40 @@ class TestHiveContext private[hive]( } + +private[hive] class TestHiveQueryExecution( + sparkSession: TestHiveSparkSession, + logicalPlan: LogicalPlan) + extends HiveQueryExecution(new SQLContext(sparkSession), logicalPlan) with Logging { + + def this(sparkSession: TestHiveSparkSession, sql: String) { + this(sparkSession, sparkSession.sessionState.sqlParser.parsePlan(sql)) + } + + def this(sql: String) { + this(TestHive.sparkSession, sql) + } + + override lazy val analyzed: LogicalPlan = { + val describedTables = logical match { + case HiveNativeCommand(sparkSession.describedTable(tbl)) => tbl :: Nil + case CacheTableCommand(tbl, _, _) => tbl :: Nil + case _ => Nil + } + + // Make sure any test tables referenced are loaded. + val referencedTables = + describedTables ++ + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } + val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + referencedTestTables.foreach(sparkSession.loadTestTable) + // Proceed with analysis. + sparkSession.sessionState.analyzer.execute(logical) + } +} + + private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry { private val removedFunctions = @@ -517,7 +534,43 @@ private[hive] class TestHiveSharedState( TestHiveContext.newClientForMetadata( sc.conf, sc.hadoopConfiguration, warehousePath, scratchDirPath, metastoreTemporaryConf) } +} + +private[hive] class TestHiveSessionState(sparkSession: TestHiveSparkSession) + extends HiveSessionState(new SQLContext(sparkSession)) { + + override lazy val conf: SQLConf = { + new SQLConf { + clear() + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + override def clear(): Unit = { + super.clear() + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } + } + } + + override lazy val functionRegistry: TestHiveFunctionRegistry = { + // We use TestHiveFunctionRegistry at here to track functions that have been explicitly + // unregistered (through TestHiveFunctionRegistry.unregisterFunction method). 
+ val fr = new TestHiveFunctionRegistry + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) + } + fr + } + + override def executePlan(plan: LogicalPlan): TestHiveQueryExecution = { + new TestHiveQueryExecution(sparkSession, plan) + } + + // Override so we can intercept relative paths and rewrite them to point at hive. + override def runNativeSql(sql: String): Seq[String] = { + super.runNativeSql(sparkSession.rewritePaths(substitutor.substitute(hiveconf, sql))) + } } @@ -552,7 +605,7 @@ private[hive] object TestHiveContext { /** * Configurations needed to create a [[HiveClient]]. */ - private def hiveClientConfigurations( + def hiveClientConfigurations( hiveconf: HiveConf, warehousePath: File, scratchDirPath: File, @@ -564,7 +617,7 @@ private[hive] object TestHiveContext { ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") } - private def makeScratchDir(): File = { + def makeScratchDir(): File = { val scratchDir = Utils.createTempDir(namePrefix = "scratch") scratchDir.delete() scratchDir diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index b9e7a36b41..61910b8e6b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -23,7 +23,6 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.execution.HiveSqlParser import org.apache.spark.sql.hive.test.TestHiveSingleton class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach { @@ -131,7 +130,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def ast = hiveContext.sessionState.sqlParser.parsePlan(query) + def ast = hiveContext.parseSql(query) def parseTree = Try(quietly(ast.treeString)).getOrElse("") test(name) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala index b644a50613..b2c0f7e0e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala @@ -28,9 +28,12 @@ class HiveContextSuite extends SparkFunSuite { val sc = TestHive.sparkContext require(sc.conf.get("spark.sql.hive.metastore.barrierPrefixes") == "org.apache.spark.sql.hive.execution.PairSerDe") - assert(TestHive.initialSQLConf.getConfString("spark.sql.hive.metastore.barrierPrefixes") == + assert(TestHive.sparkSession.initialSQLConf.getConfString( + "spark.sql.hive.metastore.barrierPrefixes") == "org.apache.spark.sql.hive.execution.PairSerDe") - assert(TestHive.metadataHive.getConf("spark.sql.hive.metastore.barrierPrefixes", "") == + // This setting should be also set in the hiveconf of the current session. 
+ assert(TestHive.sessionState.hiveconf.get( + "spark.sql.hive.metastore.barrierPrefixes", "") == "org.apache.spark.sql.hive.execution.PairSerDe") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 8648834f0d..2a201c195f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -96,7 +96,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + assert(sessionState.runNativeSql("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -129,7 +129,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + assert(sessionState.runNativeSql("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } @@ -159,7 +159,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq("int", "string")) checkAnswer(table("t"), Row(1, "val_1")) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + assert(sessionState.runNativeSql("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index d0e6870519..bbe135b2d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -253,13 +253,13 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) // Discard the cached relation. - invalidateTable("jsonTable") + sessionState.invalidateTable("jsonTable") checkAnswer( sql("SELECT * FROM jsonTable"), sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - invalidateTable("jsonTable") + sessionState.invalidateTable("jsonTable") val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) assert(expectedSchema === table("jsonTable").schema) @@ -347,7 +347,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv """.stripMargin) // Discard the cached relation. - invalidateTable("ctasJsonTable") + sessionState.invalidateTable("ctasJsonTable") // Schema should not be changed. 
assert(table("ctasJsonTable").schema === table("jsonTable").schema) @@ -422,7 +422,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), (6 to 10).map(i => Row(i, s"str$i"))) - invalidateTable("savedJsonTable") + sessionState.invalidateTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), @@ -620,7 +620,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .mode(SaveMode.Append) .saveAsTable("arrayInParquet") - refreshTable("arrayInParquet") + sessionState.refreshTable("arrayInParquet") checkAnswer( sql("SELECT a FROM arrayInParquet"), @@ -679,7 +679,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .mode(SaveMode.Append) .saveAsTable("mapInParquet") - refreshTable("mapInParquet") + sessionState.refreshTable("mapInParquet") checkAnswer( sql("SELECT a FROM mapInParquet"), @@ -707,7 +707,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv options = Map("path" -> tempDir.getCanonicalPath), isExternal = false) - invalidateTable("wide_schema") + sessionState.invalidateTable("wide_schema") val actualSchema = table("wide_schema").schema assert(schema === actualSchema) @@ -737,9 +737,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv "spark.sql.sources.schema" -> schema.json, "EXTERNAL" -> "FALSE")) - hiveCatalog.createTable("default", hiveTable, ignoreIfExists = false) + sharedState.externalCatalog.createTable("default", hiveTable, ignoreIfExists = false) - invalidateTable(tableName) + sessionState.invalidateTable(tableName) val actualSchema = table(tableName).schema assert(schema === actualSchema) } @@ -751,8 +751,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv withTable(tableName) { df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) - invalidateTable(tableName) - val metastoreTable = hiveCatalog.getTable("default", tableName) + sessionState.invalidateTable(tableName) + val metastoreTable = sharedState.externalCatalog.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val numPartCols = metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt @@ -786,8 +786,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .bucketBy(8, "d", "b") .sortBy("c") .saveAsTable(tableName) - invalidateTable(tableName) - val metastoreTable = hiveCatalog.getTable("default", tableName) + sessionState.invalidateTable(tableName) + val metastoreTable = sharedState.externalCatalog.getTable("default", tableName) val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val expectedSortByColumns = StructType(df.schema("c") :: Nil) @@ -917,7 +917,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // As a proxy for verifying that the table was stored in Hive compatible format, // we verify that each column of the table is of native type StringType. 
- assert(hiveCatalog.getTable("default", "not_skip_hive_metadata").schema + assert(sharedState.externalCatalog.getTable("default", "not_skip_hive_metadata").schema .forall(column => HiveMetastoreTypes.toDataType(column.dataType) == StringType)) sessionState.catalog.createDataSourceTable( @@ -931,9 +931,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // As a proxy for verifying that the table was stored in SparkSQL format, // we verify that the table has a column type as array of StringType. - assert(hiveCatalog.getTable("default", "skip_hive_metadata").schema.forall { c => - HiveMetastoreTypes.toDataType(c.dataType) == ArrayType(StringType) - }) + assert(sharedState.externalCatalog.getTable("default", "skip_hive_metadata") + .schema.forall { c => HiveMetastoreTypes.toDataType(c.dataType) == ArrayType(StringType) }) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 3c003506ef..850cb1eda5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -25,8 +25,9 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle private lazy val df = sqlContext.range(10).coalesce(1).toDF() private def checkTablePath(dbName: String, tableName: String): Unit = { - val metastoreTable = hiveContext.hiveCatalog.getTable(dbName, tableName) - val expectedPath = hiveContext.hiveCatalog.getDatabase(dbName).locationUri + "/" + tableName + val metastoreTable = hiveContext.sharedState.externalCatalog.getTable(dbName, tableName) + val expectedPath = + hiveContext.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName assert(metastoreTable.storage.serdeProperties("path") === expectedPath) } @@ -216,7 +217,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle df.write.parquet(s"$path/p=2") sql("ALTER TABLE t ADD PARTITION (p=2)") - hiveContext.refreshTable("t") + hiveContext.sessionState.refreshTable("t") checkAnswer( sqlContext.table("t"), df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) @@ -248,7 +249,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle df.write.parquet(s"$path/p=2") sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") - hiveContext.refreshTable(s"$db.t") + hiveContext.sessionState.refreshTable(s"$db.t") checkAnswer( sqlContext.table(s"$db.t"), df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index d14c72b34b..adc7af32ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -31,7 +31,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = hiveContext.sessionState.sqlParser.parsePlan(analyzeCommand) + val parsed = hiveContext.parseSql(analyzeCommand) val operators = parsed.collect { case a: AnalyzeTable => a case o => o @@ -116,7 +116,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") 
intercept[UnsupportedOperationException] { - hiveContext.analyze("tempTable") + hiveContext.sessionState.analyze("tempTable") } hiveContext.sessionState.catalog.dropTable( TableIdentifier("tempTable"), ignoreIfNotExists = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala index a3f5921a0c..c58a664189 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.sql.hive.execution import java.io.File -import org.apache.spark.sql.hive.test.TestHive._ /** * A set of test cases based on the big-data-benchmark. * https://amplab.cs.berkeley.edu/benchmark/ */ class BigDataBenchmarkSuite extends HiveComparisonTest { - val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") + import org.apache.spark.sql.hive.test.TestHive.sparkSession._ + val testDataDirectory = new File("target" + File.separator + "big-data-benchmark-testdata") val userVisitPath = new File(testDataDirectory, "uservisits").getCanonicalPath val testTables = Seq( TestTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index e67fcbedc3..bd46cb922e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -30,8 +30,9 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.command.{ExplainCommand, SetCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable} +import org.apache.spark.sql.hive.SQLBuilder +import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} /** * Allows the creations of tests that execute the same query against both hive @@ -141,7 +142,7 @@ abstract class HiveComparisonTest } protected def prepareAnswer( - hiveQuery: TestHive.type#QueryExecution, + hiveQuery: TestHiveQueryExecution, answer: Seq[String]): Seq[String] = { def isSorted(plan: LogicalPlan): Boolean = plan match { @@ -332,7 +333,7 @@ abstract class HiveComparisonTest hiveCachedResults } else { - val hiveQueries = queryList.map(new TestHive.QueryExecution(_)) + val hiveQueries = queryList.map(new TestHiveQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. // Note this must only look at the logical plan as we might not be able to analyze if // other DDL has not been executed yet. @@ -352,7 +353,7 @@ abstract class HiveComparisonTest case _: ExplainCommand => // No need to execute EXPLAIN queries as we don't check the output. 
Nil - case _ => TestHive.runSqlHive(queryString) + case _ => TestHive.sessionState.runNativeSql(queryString) } // We need to add a new line to non-empty answers so we can differentiate Seq() @@ -382,10 +383,10 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - var query: TestHive.QueryExecution = null + var query: TestHiveQueryExecution = null try { query = { - val originalQuery = new TestHive.QueryExecution(queryString) + val originalQuery = new TestHiveQueryExecution(queryString) val containsCommands = originalQuery.analyzed.collectFirst { case _: Command => () case _: LogicalInsertIntoHiveTable => () @@ -409,7 +410,7 @@ abstract class HiveComparisonTest } try { - val queryExecution = new TestHive.QueryExecution(convertedSQL) + val queryExecution = new TestHiveQueryExecution(convertedSQL) // Trigger the analysis of this converted SQL query. queryExecution.analyzed queryExecution @@ -472,12 +473,12 @@ abstract class HiveComparisonTest // If this query is reading other tables that were created during this test run // also print out the query plans and results for those. val computedTablesMessages: String = try { - val tablesRead = new TestHive.QueryExecution(query).executedPlan.collect { + val tablesRead = new TestHiveQueryExecution(query).executedPlan.collect { case ts: HiveTableScan => ts.relation.tableName }.toSet TestHive.reset() - val executions = queryList.map(new TestHive.QueryExecution(_)) + val executions = queryList.map(new TestHiveQueryExecution(_)) executions.foreach(_.toRdd) val tablesGenerated = queryList.zip(executions).flatMap { // We should take executedPlan instead of sparkPlan, because in following codes we @@ -562,8 +563,8 @@ abstract class HiveComparisonTest // okay by running a simple query. If this fails then we halt testing since // something must have gone seriously wrong. 
try { - new TestHive.QueryExecution("SELECT key FROM src").stringResult() - TestHive.runSqlHive("SELECT key FROM src") + new TestHiveQueryExecution("SELECT key FROM src").stringResult() + TestHive.sessionState.runNativeSql("SELECT key FROM src") } catch { case e: Exception => logError(s"FATAL ERROR: Canary query threw $e This implies that the " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2e7a1d921b..93d63f2241 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -49,7 +49,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { override def beforeAll() { super.beforeAll() - TestHive.cacheTables = true + TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -58,7 +58,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { override def afterAll() { try { - TestHive.cacheTables = false + TestHive.setCacheTables(false) TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2") @@ -1009,7 +1009,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .mkString("/") // Loads partition data to a temporary table to verify contents - val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" + val path = s"${sparkSession.warehousePath}/dynamic_part_table/$partFolder/part-00000" sql("DROP TABLE IF EXISTS dp_verify") sql("CREATE TABLE dp_verify(intcol INT)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 5586a79361..b8af0b39c8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -28,8 +28,8 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { override def beforeAll(): Unit = { import TestHive._ import org.apache.hadoop.hive.serde2.RegexSerDe - super.beforeAll() - TestHive.cacheTables = false + super.beforeAll() + TestHive.setCacheTables(false) sql(s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 97cb9d9720..79ac53c863 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -21,18 +21,22 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} /** * A set of test cases that validate partition and column pruning. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 97cb9d9720..79ac53c863 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -21,18 +21,22 @@
 import scala.collection.JavaConverters._
 
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution}
 
 /**
  * A set of test cases that validate partition and column pruning.
  */
 class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
-  TestHive.cacheTables = false
-  // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset
-  // the environment to ensure all referenced tables in this suites are not cached in-memory.
-  // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details.
-  TestHive.reset()
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    TestHive.setCacheTables(false)
+    // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet,
+    // need to reset the environment to ensure all referenced tables in this suites are
+    // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283
+    // for details.
+    TestHive.reset()
+  }
 
   // Column pruning tests
 
@@ -144,7 +148,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
       expectedScannedColumns: Seq[String],
       expectedPartValues: Seq[Seq[String]]): Unit = {
     test(s"$testCaseName - pruning test") {
-      val plan = new TestHive.QueryExecution(sql).sparkPlan
+      val plan = new TestHiveQueryExecution(sql).sparkPlan
       val actualOutputColumns = plan.output.map(_.name)
       val (actualScannedColumns, actualPartValues) = plan.collect {
         case p @ HiveTableScan(columns, relation, _) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 1098e74cab..6b71e59b73 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -349,7 +349,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       }
     }
 
-    val originalConf = convertCTAS
+    val originalConf = sessionState.convertCTAS
 
     setConf(HiveContext.CONVERT_CTAS, true)
 
@@ -731,7 +731,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     // generates an invalid query plan.
     val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}"""))
     read.json(rdd).registerTempTable("data")
-    val originalConf = convertCTAS
+    val originalConf = sessionState.convertCTAS
     setConf(HiveContext.CONVERT_CTAS, false)
 
     try {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 8f163f27c9..00b5c8dd41 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -58,7 +58,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
         ioschema = noSerdeIOSchema
-      )(hiveContext),
+      )(hiveContext.sessionState.hiveconf),
       rowsDf.collect())
   }
 
@@ -72,7 +72,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
         ioschema = serdeIOSchema
-      )(hiveContext),
+      )(hiveContext.sessionState.hiveconf),
       rowsDf.collect())
   }
 
@@ -87,7 +87,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
           output = Seq(AttributeReference("a", StringType)()),
           child = ExceptionInjectingOperator(child),
           ioschema = noSerdeIOSchema
-        )(hiveContext),
+        )(hiveContext.sessionState.hiveconf),
         rowsDf.collect())
     }
     assert(e.getMessage().contains("intentional exception"))
@@ -104,7 +104,7 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
           output = Seq(AttributeReference("a", StringType)()),
           child = ExceptionInjectingOperator(child),
           ioschema = serdeIOSchema
-        )(hiveContext),
+        )(hiveContext.sessionState.hiveconf),
         rowsDf.collect())
     }
     assert(e.getMessage().contains("intentional exception"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 4b2b1a160a..6fa4c3334f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -461,7 +461,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
     checkCached(tableIdentifier)
     // For insert into non-partitioned table, we will do the conversion,
     // so the converted test_insert_parquet should be cached.
-    invalidateTable("test_insert_parquet")
+    sessionState.invalidateTable("test_insert_parquet")
     assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null)
     sql(
       """
@@ -474,7 +474,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
       sql("select * from test_insert_parquet"),
       sql("select a, b from jt").collect())
     // Invalidate the cache.
-    invalidateTable("test_insert_parquet")
+    sessionState.invalidateTable("test_insert_parquet")
     assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null)
 
     // Create a partitioned table.
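In the Parquet metastore hunks above and below, cache invalidation moves from a method on the context to the session state. The following sketch shows the invalidate-and-verify pattern they repeat; the helper is hypothetical, and the exact HiveSessionState signatures are assumptions read off these hunks.

// Hypothetical sketch, placed in the hive package so session-state members stay visible.
package org.apache.spark.sql.hive

import org.apache.spark.sql.catalyst.TableIdentifier

object InvalidationSketch {
  // Drop the cached data source table, then confirm the catalog no longer holds it.
  def invalidateAndVerify(sessionState: HiveSessionState, table: TableIdentifier): Unit = {
    sessionState.invalidateTable(table.table)  // was: invalidateTable(...) called on the Hive context
    assert(sessionState.catalog.getCachedDataSourceTable(table) == null)
  }
}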
@@ -524,7 +524,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
         |select b, '2015-04-02', a FROM jt
       """.stripMargin).collect())
 
-    invalidateTable("test_parquet_partitioned_cache_test")
+    sessionState.invalidateTable("test_parquet_partitioned_cache_test")
     assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null)
 
     dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index a0be55cfba..aa6101f7b7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -349,7 +349,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
-      val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+      val tableDir = new File(hiveContext.sparkSession.warehousePath, "bucketed_table")
       Utils.deleteRecursively(tableDir)
       df1.write.parquet(tableDir.getAbsolutePath)