[SPARK-15075][SPARK-15345][SQL] Clean up SparkSession builder and propagate config options to existing sessions if specified
## What changes were proposed in this pull request? Currently SparkSession.Builder use SQLContext.getOrCreate. It should probably the the other way around, i.e. all the core logic goes in SparkSession, and SQLContext just calls that. This patch does that. This patch also makes sure config options specified in the builder are propagated to the existing (and of course the new) SparkSession. ## How was this patch tested? Updated tests to reflect the change, and also introduced a new SparkSessionBuilderSuite that should cover all the branches. Author: Reynold Xin <rxin@databricks.com> Closes #13200 from rxin/SPARK-15075.
This commit is contained in:
parent
17591d90e6
commit
f2ee0ed4b7
|
@ -56,7 +56,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession {
|
|||
} catch (IOException e) {
|
||||
// expected
|
||||
}
|
||||
instance.write().context(spark.wrapped()).overwrite().save(outputPath);
|
||||
instance.write().context(spark.sqlContext()).overwrite().save(outputPath);
|
||||
MyParams newInstance = MyParams.load(outputPath);
|
||||
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
|
||||
Assert.assertEquals("Params should be preserved.",
|
||||
|
|
|
@ -34,7 +34,10 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
|
|||
|
||||
|
||||
class SQLContext(object):
|
||||
"""Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.
|
||||
"""The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
|
||||
|
||||
As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class
|
||||
here for backward compatibility.
|
||||
|
||||
A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
|
||||
tables, execute SQL over tables, cache tables, and read parquet files.
|
||||
|
|
|
@ -120,6 +120,8 @@ class SparkSession(object):
|
|||
def appName(self, name):
|
||||
"""Sets a name for the application, which will be shown in the Spark web UI.
|
||||
|
||||
If no application name is set, a randomly generated name will be used.
|
||||
|
||||
:param name: an application name
|
||||
"""
|
||||
return self.config("spark.app.name", name)
|
||||
|
@ -133,8 +135,17 @@ class SparkSession(object):
|
|||
|
||||
@since(2.0)
|
||||
def getOrCreate(self):
|
||||
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new
|
||||
one based on the options set in this builder.
|
||||
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
|
||||
new one based on the options set in this builder.
|
||||
|
||||
This method first checks whether there is a valid thread-local SparkSession,
|
||||
and if yes, return that one. It then checks whether there is a valid global
|
||||
default SparkSession, and if yes, return that one. If no valid global default
|
||||
SparkSession exists, the method creates a new SparkSession and assigns the
|
||||
newly created SparkSession as the global default.
|
||||
|
||||
In case an existing SparkSession is returned, the config options specified
|
||||
in this builder will be applied to the existing SparkSession.
|
||||
"""
|
||||
with self._lock:
|
||||
from pyspark.conf import SparkConf
|
||||
|
@ -175,7 +186,7 @@ class SparkSession(object):
|
|||
if jsparkSession is None:
|
||||
jsparkSession = self._jvm.SparkSession(self._jsc.sc())
|
||||
self._jsparkSession = jsparkSession
|
||||
self._jwrapped = self._jsparkSession.wrapped()
|
||||
self._jwrapped = self._jsparkSession.sqlContext()
|
||||
self._wrapped = SQLContext(self._sc, self, self._jwrapped)
|
||||
_monkey_patch_RDD(self)
|
||||
install_exception_handler()
|
||||
|
|
|
@ -213,7 +213,7 @@ class Dataset[T] private[sql](
|
|||
private implicit def classTag = unresolvedTEncoder.clsTag
|
||||
|
||||
// sqlContext must be val because a stable identifier is expected when you import implicits
|
||||
@transient lazy val sqlContext: SQLContext = sparkSession.wrapped
|
||||
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
protected[sql] def resolve(colName: String): NamedExpression = {
|
||||
queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
|
||||
|
|
|
@ -19,25 +19,22 @@ package org.apache.spark.sql
|
|||
|
||||
import java.beans.BeanInfo
|
||||
import java.util.Properties
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
import scala.collection.immutable
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext, SparkException}
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.annotation.{DeveloperApi, Experimental}
|
||||
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.internal.config.ConfigEntry
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
|
||||
import org.apache.spark.sql.catalyst._
|
||||
import org.apache.spark.sql.catalyst.catalog._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.command.ShowTablesCommand
|
||||
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
|
||||
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
|
||||
import org.apache.spark.sql.sources.BaseRelation
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -46,8 +43,8 @@ import org.apache.spark.sql.util.ExecutionListenerManager
|
|||
/**
|
||||
* The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
|
||||
*
|
||||
* As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class here
|
||||
* for backward compatibility.
|
||||
* As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class
|
||||
* here for backward compatibility.
|
||||
*
|
||||
* @groupname basic Basic Operations
|
||||
* @groupname ddl_ops Persistent Catalog DDL
|
||||
|
@ -76,42 +73,21 @@ class SQLContext private[sql](
|
|||
this(sparkSession, true)
|
||||
}
|
||||
|
||||
@deprecated("Use SparkSession.builder instead", "2.0.0")
|
||||
def this(sc: SparkContext) = {
|
||||
this(new SparkSession(sc))
|
||||
}
|
||||
|
||||
@deprecated("Use SparkSession.builder instead", "2.0.0")
|
||||
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
|
||||
|
||||
// TODO: move this logic into SparkSession
|
||||
|
||||
// If spark.sql.allowMultipleContexts is true, we will throw an exception if a user
|
||||
// wants to create a new root SQLContext (a SQLContext that is not created by newSession).
|
||||
private val allowMultipleContexts =
|
||||
sparkContext.conf.getBoolean(
|
||||
SQLConf.ALLOW_MULTIPLE_CONTEXTS.key,
|
||||
SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get)
|
||||
|
||||
// Assert no root SQLContext is running when allowMultipleContexts is false.
|
||||
{
|
||||
if (!allowMultipleContexts && isRootContext) {
|
||||
SQLContext.getInstantiatedContextOption() match {
|
||||
case Some(rootSQLContext) =>
|
||||
val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " +
|
||||
s"It is recommended to use SQLContext.getOrCreate to get the instantiated " +
|
||||
s"SQLContext/HiveContext. To ignore this error, " +
|
||||
s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf."
|
||||
throw new SparkException(errMsg)
|
||||
case None => // OK
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected[sql] def sessionState: SessionState = sparkSession.sessionState
|
||||
protected[sql] def sharedState: SharedState = sparkSession.sharedState
|
||||
protected[sql] def conf: SQLConf = sessionState.conf
|
||||
protected[sql] def runtimeConf: RuntimeConfig = sparkSession.conf
|
||||
protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager
|
||||
protected[sql] def listener: SQLListener = sparkSession.listener
|
||||
protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog
|
||||
|
||||
def sparkContext: SparkContext = sparkSession.sparkContext
|
||||
|
@ -123,7 +99,7 @@ class SQLContext private[sql](
|
|||
*
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def newSession(): SQLContext = sparkSession.newSession().wrapped
|
||||
def newSession(): SQLContext = sparkSession.newSession().sqlContext
|
||||
|
||||
/**
|
||||
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
|
||||
|
@ -760,21 +736,6 @@ class SQLContext private[sql](
|
|||
schema: StructType): DataFrame = {
|
||||
sparkSession.applySchemaToPythonRDD(rdd, schema)
|
||||
}
|
||||
|
||||
// TODO: move this logic into SparkSession
|
||||
|
||||
// Register a successfully instantiated context to the singleton. This should be at the end of
|
||||
// the class definition so that the singleton is updated only if there is no exception in the
|
||||
// construction of the instance.
|
||||
sparkContext.addSparkListener(new SparkListener {
|
||||
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
|
||||
SQLContext.clearInstantiatedContext()
|
||||
SQLContext.clearSqlListener()
|
||||
}
|
||||
})
|
||||
|
||||
sparkSession.setWrappedContext(self)
|
||||
SQLContext.setInstantiatedContext(self)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -787,19 +748,6 @@ class SQLContext private[sql](
|
|||
*/
|
||||
object SQLContext {
|
||||
|
||||
/**
|
||||
* The active SQLContext for the current thread.
|
||||
*/
|
||||
private val activeContext: InheritableThreadLocal[SQLContext] =
|
||||
new InheritableThreadLocal[SQLContext]
|
||||
|
||||
/**
|
||||
* Reference to the created SQLContext.
|
||||
*/
|
||||
@transient private val instantiatedContext = new AtomicReference[SQLContext]()
|
||||
|
||||
@transient private val sqlListener = new AtomicReference[SQLListener]()
|
||||
|
||||
/**
|
||||
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
|
||||
*
|
||||
|
@ -811,41 +759,9 @@ object SQLContext {
|
|||
*
|
||||
* @since 1.5.0
|
||||
*/
|
||||
@deprecated("Use SparkSession.builder instead", "2.0.0")
|
||||
def getOrCreate(sparkContext: SparkContext): SQLContext = {
|
||||
val ctx = activeContext.get()
|
||||
if (ctx != null && !ctx.sparkContext.isStopped) {
|
||||
return ctx
|
||||
}
|
||||
|
||||
synchronized {
|
||||
val ctx = instantiatedContext.get()
|
||||
if (ctx == null || ctx.sparkContext.isStopped) {
|
||||
new SQLContext(sparkContext)
|
||||
} else {
|
||||
ctx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] def clearInstantiatedContext(): Unit = {
|
||||
instantiatedContext.set(null)
|
||||
}
|
||||
|
||||
private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
|
||||
synchronized {
|
||||
val ctx = instantiatedContext.get()
|
||||
if (ctx == null || ctx.sparkContext.isStopped) {
|
||||
instantiatedContext.set(sqlContext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] def getInstantiatedContextOption(): Option[SQLContext] = {
|
||||
Option(instantiatedContext.get())
|
||||
}
|
||||
|
||||
private[sql] def clearSqlListener(): Unit = {
|
||||
sqlListener.set(null)
|
||||
SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -855,8 +771,9 @@ object SQLContext {
|
|||
*
|
||||
* @since 1.6.0
|
||||
*/
|
||||
@deprecated("Use SparkSession.setActiveSession instead", "2.0.0")
|
||||
def setActive(sqlContext: SQLContext): Unit = {
|
||||
activeContext.set(sqlContext)
|
||||
SparkSession.setActiveSession(sqlContext.sparkSession)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -865,12 +782,9 @@ object SQLContext {
|
|||
*
|
||||
* @since 1.6.0
|
||||
*/
|
||||
@deprecated("Use SparkSession.clearActiveSession instead", "2.0.0")
|
||||
def clearActive(): Unit = {
|
||||
activeContext.remove()
|
||||
}
|
||||
|
||||
private[sql] def getActive(): Option[SQLContext] = {
|
||||
Option(activeContext.get())
|
||||
SparkSession.clearActiveSession()
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -894,20 +808,6 @@ object SQLContext {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
|
||||
*/
|
||||
private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = {
|
||||
if (sqlListener.get() == null) {
|
||||
val listener = new SQLListener(sc.conf)
|
||||
if (sqlListener.compareAndSet(null, listener)) {
|
||||
sc.addSparkListener(listener)
|
||||
sc.ui.foreach(new SQLTab(listener, _))
|
||||
}
|
||||
}
|
||||
sqlListener.get()
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
|
||||
*/
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import java.beans.Introspector
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.reflect.ClassTag
|
||||
|
@ -30,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD
|
|||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
|
||||
import org.apache.spark.sql.catalog.Catalog
|
||||
import org.apache.spark.sql.catalyst._
|
||||
import org.apache.spark.sql.catalyst.catalog._
|
||||
|
@ -98,24 +100,10 @@ class SparkSession private(
|
|||
}
|
||||
|
||||
/**
|
||||
* A wrapped version of this session in the form of a [[SQLContext]].
|
||||
* A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility.
|
||||
*/
|
||||
@transient
|
||||
private var _wrapped: SQLContext = _
|
||||
|
||||
@transient
|
||||
private val _wrappedLock = new Object
|
||||
|
||||
protected[sql] def wrapped: SQLContext = _wrappedLock.synchronized {
|
||||
if (_wrapped == null) {
|
||||
_wrapped = new SQLContext(self, isRootContext = false)
|
||||
}
|
||||
_wrapped
|
||||
}
|
||||
|
||||
protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = _wrappedLock.synchronized {
|
||||
_wrapped = sqlContext
|
||||
}
|
||||
private[sql] val sqlContext: SQLContext = new SQLContext(this)
|
||||
|
||||
protected[sql] def cacheManager: CacheManager = sharedState.cacheManager
|
||||
protected[sql] def listener: SQLListener = sharedState.listener
|
||||
|
@ -238,7 +226,7 @@ class SparkSession private(
|
|||
*/
|
||||
@Experimental
|
||||
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
|
||||
SQLContext.setActive(wrapped)
|
||||
SparkSession.setActiveSession(this)
|
||||
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
|
||||
val attributeSeq = schema.toAttributes
|
||||
val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
|
||||
|
@ -254,7 +242,7 @@ class SparkSession private(
|
|||
*/
|
||||
@Experimental
|
||||
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
|
||||
SQLContext.setActive(wrapped)
|
||||
SparkSession.setActiveSession(this)
|
||||
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
|
||||
val attributeSeq = schema.toAttributes
|
||||
Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
|
||||
|
@ -573,7 +561,7 @@ class SparkSession private(
|
|||
*/
|
||||
@Experimental
|
||||
object implicits extends SQLImplicits with Serializable {
|
||||
protected override def _sqlContext: SQLContext = wrapped
|
||||
protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext
|
||||
}
|
||||
// scalastyle:on
|
||||
|
||||
|
@ -649,8 +637,16 @@ object SparkSession {
|
|||
|
||||
private[this] val options = new scala.collection.mutable.HashMap[String, String]
|
||||
|
||||
private[this] var userSuppliedContext: Option[SparkContext] = None
|
||||
|
||||
private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
|
||||
userSuppliedContext = Option(sparkContext)
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a name for the application, which will be shown in the Spark web UI.
|
||||
* If no application name is set, a randomly generated name will be used.
|
||||
*
|
||||
* @since 2.0.0
|
||||
*/
|
||||
|
@ -735,29 +731,130 @@ object SparkSession {
|
|||
}
|
||||
|
||||
/**
|
||||
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new one
|
||||
* based on the options set in this builder.
|
||||
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
|
||||
* one based on the options set in this builder.
|
||||
*
|
||||
* This method first checks whether there is a valid thread-local SparkSession,
|
||||
* and if yes, return that one. It then checks whether there is a valid global
|
||||
* default SparkSession, and if yes, return that one. If no valid global default
|
||||
* SparkSession exists, the method creates a new SparkSession and assigns the
|
||||
* newly created SparkSession as the global default.
|
||||
*
|
||||
* In case an existing SparkSession is returned, the config options specified in
|
||||
* this builder will be applied to the existing SparkSession.
|
||||
*
|
||||
* @since 2.0.0
|
||||
*/
|
||||
def getOrCreate(): SparkSession = synchronized {
|
||||
// Step 1. Create a SparkConf
|
||||
// Step 2. Get a SparkContext
|
||||
// Step 3. Get a SparkSession
|
||||
val sparkConf = new SparkConf()
|
||||
options.foreach { case (k, v) => sparkConf.set(k, v) }
|
||||
val sparkContext = SparkContext.getOrCreate(sparkConf)
|
||||
// Get the session from current thread's active session.
|
||||
var session = activeThreadSession.get()
|
||||
if ((session ne null) && !session.sparkContext.isStopped) {
|
||||
options.foreach { case (k, v) => session.conf.set(k, v) }
|
||||
return session
|
||||
}
|
||||
|
||||
SQLContext.getOrCreate(sparkContext).sparkSession
|
||||
// Global synchronization so we will only set the default session once.
|
||||
SparkSession.synchronized {
|
||||
// If the current thread does not have an active session, get it from the global session.
|
||||
session = defaultSession.get()
|
||||
if ((session ne null) && !session.sparkContext.isStopped) {
|
||||
options.foreach { case (k, v) => session.conf.set(k, v) }
|
||||
return session
|
||||
}
|
||||
|
||||
// No active nor global default session. Create a new one.
|
||||
val sparkContext = userSuppliedContext.getOrElse {
|
||||
// set app name if not given
|
||||
if (!options.contains("spark.app.name")) {
|
||||
options += "spark.app.name" -> java.util.UUID.randomUUID().toString
|
||||
}
|
||||
|
||||
val sparkConf = new SparkConf()
|
||||
options.foreach { case (k, v) => sparkConf.set(k, v) }
|
||||
SparkContext.getOrCreate(sparkConf)
|
||||
}
|
||||
session = new SparkSession(sparkContext)
|
||||
options.foreach { case (k, v) => session.conf.set(k, v) }
|
||||
defaultSession.set(session)
|
||||
|
||||
// Register a successfully instantiated context to the singleton. This should be at the
|
||||
// end of the class definition so that the singleton is updated only if there is no
|
||||
// exception in the construction of the instance.
|
||||
sparkContext.addSparkListener(new SparkListener {
|
||||
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
|
||||
defaultSession.set(null)
|
||||
sqlListener.set(null)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
|
||||
*
|
||||
* @since 2.0.0
|
||||
*/
|
||||
def builder(): Builder = new Builder
|
||||
|
||||
/**
|
||||
* Changes the SparkSession that will be returned in this thread and its children when
|
||||
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
|
||||
* a SparkSession with an isolated session, instead of the global (first created) context.
|
||||
*
|
||||
* @since 2.0.0
|
||||
*/
|
||||
def setActiveSession(session: SparkSession): Unit = {
|
||||
activeThreadSession.set(session)
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will
|
||||
* return the first created context instead of a thread-local override.
|
||||
*
|
||||
* @since 2.0.0
|
||||
*/
|
||||
def clearActiveSession(): Unit = {
|
||||
activeThreadSession.remove()
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the default SparkSession that is returned by the builder.
|
||||
*
|
||||
* @since 2.0.0
|
||||
*/
|
||||
def setDefaultSession(session: SparkSession): Unit = {
|
||||
defaultSession.set(session)
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the default SparkSession that is returned by the builder.
|
||||
*
|
||||
* @since 2.0.0
|
||||
*/
|
||||
def clearDefaultSession(): Unit = {
|
||||
defaultSession.set(null)
|
||||
}
|
||||
|
||||
private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
|
||||
|
||||
private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
|
||||
|
||||
/** A global SQL listener used for the SQL UI. */
|
||||
private[sql] val sqlListener = new AtomicReference[SQLListener]()
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Private methods from now on
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** The active SparkSession for the current thread. */
|
||||
private val activeThreadSession = new InheritableThreadLocal[SparkSession]
|
||||
|
||||
/** Reference to the root SparkSession. */
|
||||
private val defaultSession = new AtomicReference[SparkSession]
|
||||
|
||||
private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState"
|
||||
private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
|
||||
|
||||
|
|
|
@ -157,7 +157,8 @@ private[sql] case class RowDataSourceScanExec(
|
|||
|
||||
val outputUnsafeRows = relation match {
|
||||
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
|
||||
!SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
|
||||
!SparkSession.getActiveSession.get.sessionState.conf.getConf(
|
||||
SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
|
||||
case _: HadoopFsRelation => true
|
||||
case _ => false
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
|
|||
}
|
||||
|
||||
lazy val analyzed: LogicalPlan = {
|
||||
SQLContext.setActive(sparkSession.wrapped)
|
||||
SparkSession.setActiveSession(sparkSession)
|
||||
sparkSession.sessionState.analyzer.execute(logical)
|
||||
}
|
||||
|
||||
|
@ -73,7 +73,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
|
|||
lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)
|
||||
|
||||
lazy val sparkPlan: SparkPlan = {
|
||||
SQLContext.setActive(sparkSession.wrapped)
|
||||
SparkSession.setActiveSession(sparkSession)
|
||||
planner.plan(ReturnAnswer(optimizedPlan)).next()
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.{broadcast, SparkEnv}
|
|||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.io.CompressionCodec
|
||||
import org.apache.spark.rdd.{RDD, RDDOperationScope}
|
||||
import org.apache.spark.sql.{Row, SQLContext}
|
||||
import org.apache.spark.sql.{Row, SparkSession, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
|
@ -50,7 +50,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
|
|||
* populated by the query planning infrastructure.
|
||||
*/
|
||||
@transient
|
||||
protected[spark] final val sqlContext = SQLContext.getActive().orNull
|
||||
final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull
|
||||
|
||||
protected def sparkContext = sqlContext.sparkContext
|
||||
|
||||
|
@ -65,7 +65,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
|
|||
|
||||
/** Overridden make copy also propagates sqlContext to copied plan. */
|
||||
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
|
||||
SQLContext.setActive(sqlContext)
|
||||
SparkSession.setActiveSession(sqlContext.sparkSession)
|
||||
super.makeCopy(newArgs)
|
||||
}
|
||||
|
||||
|
|
|
@ -178,7 +178,7 @@ case class DataSource(
|
|||
providingClass.newInstance() match {
|
||||
case s: StreamSourceProvider =>
|
||||
val (name, schema) = s.sourceSchema(
|
||||
sparkSession.wrapped, userSpecifiedSchema, className, options)
|
||||
sparkSession.sqlContext, userSpecifiedSchema, className, options)
|
||||
SourceInfo(name, schema)
|
||||
|
||||
case format: FileFormat =>
|
||||
|
@ -198,7 +198,8 @@ case class DataSource(
|
|||
def createSource(metadataPath: String): Source = {
|
||||
providingClass.newInstance() match {
|
||||
case s: StreamSourceProvider =>
|
||||
s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options)
|
||||
s.createSource(
|
||||
sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options)
|
||||
|
||||
case format: FileFormat =>
|
||||
val path = new CaseInsensitiveMap(options).getOrElse("path", {
|
||||
|
@ -215,7 +216,7 @@ case class DataSource(
|
|||
/** Returns a sink that can be used to continually write data. */
|
||||
def createSink(): Sink = {
|
||||
providingClass.newInstance() match {
|
||||
case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns)
|
||||
case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns)
|
||||
|
||||
case parquet: parquet.DefaultSource =>
|
||||
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
|
||||
|
@ -265,9 +266,9 @@ case class DataSource(
|
|||
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
|
||||
// TODO: Throw when too much is given.
|
||||
case (dataSource: SchemaRelationProvider, Some(schema)) =>
|
||||
dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema)
|
||||
dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
|
||||
case (dataSource: RelationProvider, None) =>
|
||||
dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions)
|
||||
dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
|
||||
case (_: SchemaRelationProvider, None) =>
|
||||
throw new AnalysisException(s"A schema needs to be specified when using $className.")
|
||||
case (_: RelationProvider, Some(_)) =>
|
||||
|
@ -383,7 +384,7 @@ case class DataSource(
|
|||
|
||||
providingClass.newInstance() match {
|
||||
case dataSource: CreatableRelationProvider =>
|
||||
dataSource.createRelation(sparkSession.wrapped, mode, options, data)
|
||||
dataSource.createRelation(sparkSession.sqlContext, mode, options, data)
|
||||
case format: FileFormat =>
|
||||
// Don't glob path for the write path. The contracts here are:
|
||||
// 1. Only one output path can be specified on the write path;
|
||||
|
|
|
@ -142,7 +142,7 @@ case class HadoopFsRelation(
|
|||
fileFormat: FileFormat,
|
||||
options: Map[String, String]) extends BaseRelation with FileRelation {
|
||||
|
||||
override def sqlContext: SQLContext = sparkSession.wrapped
|
||||
override def sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
val schema: StructType = {
|
||||
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
|
||||
|
|
|
@ -92,7 +92,7 @@ private[sql] case class JDBCRelation(
|
|||
with PrunedFilteredScan
|
||||
with InsertableRelation {
|
||||
|
||||
override def sqlContext: SQLContext = sparkSession.wrapped
|
||||
override def sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
override val needConversion: Boolean = false
|
||||
|
||||
|
|
|
@ -173,7 +173,7 @@ class StreamExecution(
|
|||
startLatch.countDown()
|
||||
|
||||
// While active, repeatedly attempt to run batches.
|
||||
SQLContext.setActive(sparkSession.wrapped)
|
||||
SparkSession.setActiveSession(sparkSession)
|
||||
|
||||
triggerExecutor.execute(() => {
|
||||
if (isActive) {
|
||||
|
|
|
@ -1168,7 +1168,7 @@ object functions {
|
|||
* @group normal_funcs
|
||||
*/
|
||||
def expr(expr: String): Column = {
|
||||
val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse {
|
||||
val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
|
||||
new SparkSqlParser(new SQLConf)
|
||||
}
|
||||
Column(parser.parseExpression(expr))
|
||||
|
|
|
@ -70,16 +70,6 @@ object SQLConf {
|
|||
.intConf
|
||||
.createWithDefault(10)
|
||||
|
||||
val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts")
|
||||
.doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " +
|
||||
"When set to false, only one SQLContext/HiveContext is allowed to be created " +
|
||||
"through the constructor (new SQLContexts/HiveContexts created through newSession " +
|
||||
"method is allowed). Please note that this conf needs to be set in Spark Conf. Once " +
|
||||
"a SQLContext/HiveContext has been created, changing the value of this conf will not " +
|
||||
"have effect.")
|
||||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed")
|
||||
.internal()
|
||||
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
|
||||
|
|
|
@ -18,10 +18,10 @@
|
|||
package org.apache.spark.sql.internal
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.SQLContext
|
||||
import org.apache.spark.sql.{SparkSession, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
|
||||
import org.apache.spark.sql.execution.CacheManager
|
||||
import org.apache.spark.sql.execution.ui.SQLListener
|
||||
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
|
||||
import org.apache.spark.util.MutableURLClassLoader
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
|
|||
/**
|
||||
* A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s.
|
||||
*/
|
||||
val listener: SQLListener = SQLContext.createListenerAndUI(sparkContext)
|
||||
val listener: SQLListener = createListenerAndUI(sparkContext)
|
||||
|
||||
/**
|
||||
* A catalog that interacts with external systems.
|
||||
|
@ -51,6 +51,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
|
|||
val jarClassLoader = new NonClosableMutableURLClassLoader(
|
||||
org.apache.spark.util.Utils.getContextOrSparkClassLoader)
|
||||
|
||||
/**
|
||||
* Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
|
||||
*/
|
||||
private def createListenerAndUI(sc: SparkContext): SQLListener = {
|
||||
if (SparkSession.sqlListener.get() == null) {
|
||||
val listener = new SQLListener(sc.conf)
|
||||
if (SparkSession.sqlListener.compareAndSet(null, listener)) {
|
||||
sc.addSparkListener(listener)
|
||||
sc.ui.foreach(new SQLTab(listener, _))
|
||||
}
|
||||
}
|
||||
SparkSession.sqlListener.get()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
|
|||
|
||||
test("get all tables") {
|
||||
checkAnswer(
|
||||
spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"),
|
||||
spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'"),
|
||||
Row("listtablessuitetable", true))
|
||||
|
||||
checkAnswer(
|
||||
|
@ -48,12 +48,12 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
|
|||
|
||||
spark.sessionState.catalog.dropTable(
|
||||
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
|
||||
assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
|
||||
assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
|
||||
}
|
||||
|
||||
test("getting all tables with a database name has no impact on returned table names") {
|
||||
checkAnswer(
|
||||
spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"),
|
||||
spark.sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"),
|
||||
Row("listtablessuitetable", true))
|
||||
|
||||
checkAnswer(
|
||||
|
@ -62,7 +62,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
|
|||
|
||||
spark.sessionState.catalog.dropTable(
|
||||
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
|
||||
assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
|
||||
assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
|
||||
}
|
||||
|
||||
test("query the returned DataFrame of tables") {
|
||||
|
@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
|
|||
StructField("tableName", StringType, false) ::
|
||||
StructField("isTemporary", BooleanType, false) :: Nil)
|
||||
|
||||
Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach {
|
||||
Seq(spark.sqlContext.tables(), sql("SHOW TABLes")).foreach {
|
||||
case tableDF =>
|
||||
assert(expectedSchema === tableDF.schema)
|
||||
|
||||
|
@ -81,7 +81,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
|
|||
Row(true, "listtablessuitetable")
|
||||
)
|
||||
checkAnswer(
|
||||
spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
|
||||
spark.sqlContext.tables()
|
||||
.filter("tableName = 'tables'").select("tableName", "isTemporary"),
|
||||
Row("tables", true))
|
||||
spark.catalog.dropTempView("tables")
|
||||
}
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
|
||||
|
||||
private var originalActiveSQLContext: Option[SQLContext] = _
|
||||
private var originalInstantiatedSQLContext: Option[SQLContext] = _
|
||||
private var sparkConf: SparkConf = _
|
||||
|
||||
override protected def beforeAll(): Unit = {
|
||||
originalActiveSQLContext = SQLContext.getActive()
|
||||
originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
|
||||
|
||||
SQLContext.clearActive()
|
||||
SQLContext.clearInstantiatedContext()
|
||||
sparkConf =
|
||||
new SparkConf(false)
|
||||
.setMaster("local[*]")
|
||||
.setAppName("test")
|
||||
.set("spark.ui.enabled", "false")
|
||||
.set("spark.driver.allowMultipleContexts", "true")
|
||||
}
|
||||
|
||||
override protected def afterAll(): Unit = {
|
||||
// Set these states back.
|
||||
originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
|
||||
originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
|
||||
}
|
||||
|
||||
def testNewSession(rootSQLContext: SQLContext): Unit = {
|
||||
// Make sure we can successfully create new Session.
|
||||
rootSQLContext.newSession()
|
||||
|
||||
// Reset the state. It is always safe to clear the active context.
|
||||
SQLContext.clearActive()
|
||||
}
|
||||
|
||||
def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = {
|
||||
val conf =
|
||||
sparkConf
|
||||
.clone
|
||||
.set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString)
|
||||
val sparkContext = new SparkContext(conf)
|
||||
|
||||
try {
|
||||
if (allowsMultipleContexts) {
|
||||
new SQLContext(sparkContext)
|
||||
SQLContext.clearActive()
|
||||
} else {
|
||||
// If allowsMultipleContexts is false, make sure we can get the error.
|
||||
val message = intercept[SparkException] {
|
||||
new SQLContext(sparkContext)
|
||||
}.getMessage
|
||||
assert(message.contains("Only one SQLContext/HiveContext may be running"))
|
||||
}
|
||||
} finally {
|
||||
sparkContext.stop()
|
||||
}
|
||||
}
|
||||
|
||||
test("test the flag to disallow creating multiple root SQLContext") {
|
||||
Seq(false, true).foreach { allowMultipleSQLContexts =>
|
||||
val conf =
|
||||
sparkConf
|
||||
.clone
|
||||
.set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString)
|
||||
val sc = new SparkContext(conf)
|
||||
try {
|
||||
val rootSQLContext = new SQLContext(sc)
|
||||
testNewSession(rootSQLContext)
|
||||
testNewSession(rootSQLContext)
|
||||
testCreatingNewSQLContext(allowMultipleSQLContexts)
|
||||
} finally {
|
||||
sc.stop()
|
||||
SQLContext.clearInstantiatedContext()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -40,7 +40,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
|
|||
val newSession = sqlContext.newSession()
|
||||
assert(SQLContext.getOrCreate(sc).eq(sqlContext),
|
||||
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
|
||||
SQLContext.setActive(newSession)
|
||||
SparkSession.setActiveSession(newSession.sparkSession)
|
||||
assert(SQLContext.getOrCreate(sc).eq(newSession),
|
||||
"SQLContext.getOrCreate after explicitly setActive() did not return the active context")
|
||||
}
|
||||
|
|
|
@ -1042,7 +1042,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("SET commands semantics using sql()") {
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
val testKey = "test.key.0"
|
||||
val testVal = "test.val.0"
|
||||
val nonexistentKey = "nonexistent"
|
||||
|
@ -1083,17 +1083,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
sql(s"SET $nonexistentKey"),
|
||||
Row(nonexistentKey, "<undefined>")
|
||||
)
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
}
|
||||
|
||||
test("SET commands with illegal or inappropriate argument") {
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
// Set negative mapred.reduce.tasks for automatically determining
|
||||
// the number of reducers is not supported
|
||||
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1"))
|
||||
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01"))
|
||||
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2"))
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
}
|
||||
|
||||
test("apply schema") {
|
||||
|
|
|
@ -25,6 +25,6 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext {
|
|||
|
||||
test("[SPARK-5235] SQLContext should be serializable") {
|
||||
val spark = SparkSession.builder.getOrCreate()
|
||||
new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped)
|
||||
new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.{SparkContext, SparkFunSuite}
|
||||
|
||||
/**
|
||||
* Test cases for the builder pattern of [[SparkSession]].
|
||||
*/
|
||||
class SparkSessionBuilderSuite extends SparkFunSuite {
|
||||
|
||||
private var initialSession: SparkSession = _
|
||||
|
||||
private lazy val sparkContext: SparkContext = {
|
||||
initialSession = SparkSession.builder()
|
||||
.master("local")
|
||||
.config("spark.ui.enabled", value = false)
|
||||
.config("some-config", "v2")
|
||||
.getOrCreate()
|
||||
initialSession.sparkContext
|
||||
}
|
||||
|
||||
test("create with config options and propagate them to SparkContext and SparkSession") {
|
||||
// Creating a new session with config - this works by just calling the lazy val
|
||||
sparkContext
|
||||
assert(initialSession.sparkContext.conf.get("some-config") == "v2")
|
||||
assert(initialSession.conf.get("some-config") == "v2")
|
||||
SparkSession.clearDefaultSession()
|
||||
}
|
||||
|
||||
test("use global default session") {
|
||||
val session = SparkSession.builder().getOrCreate()
|
||||
assert(SparkSession.builder().getOrCreate() == session)
|
||||
SparkSession.clearDefaultSession()
|
||||
}
|
||||
|
||||
test("config options are propagated to existing SparkSession") {
|
||||
val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate()
|
||||
assert(session1.conf.get("spark-config1") == "a")
|
||||
val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate()
|
||||
assert(session1 == session2)
|
||||
assert(session1.conf.get("spark-config1") == "b")
|
||||
SparkSession.clearDefaultSession()
|
||||
}
|
||||
|
||||
test("use session from active thread session and propagate config options") {
|
||||
val defaultSession = SparkSession.builder().getOrCreate()
|
||||
val activeSession = defaultSession.newSession()
|
||||
SparkSession.setActiveSession(activeSession)
|
||||
val session = SparkSession.builder().config("spark-config2", "a").getOrCreate()
|
||||
|
||||
assert(activeSession != defaultSession)
|
||||
assert(session == activeSession)
|
||||
assert(session.conf.get("spark-config2") == "a")
|
||||
SparkSession.clearActiveSession()
|
||||
|
||||
assert(SparkSession.builder().getOrCreate() == defaultSession)
|
||||
SparkSession.clearDefaultSession()
|
||||
}
|
||||
|
||||
test("create a new session if the default session has been stopped") {
|
||||
val defaultSession = SparkSession.builder().getOrCreate()
|
||||
SparkSession.setDefaultSession(defaultSession)
|
||||
defaultSession.stop()
|
||||
val newSession = SparkSession.builder().master("local").getOrCreate()
|
||||
assert(newSession != defaultSession)
|
||||
newSession.stop()
|
||||
}
|
||||
|
||||
test("create a new session if the active thread session has been stopped") {
|
||||
val activeSession = SparkSession.builder().master("local").getOrCreate()
|
||||
SparkSession.setActiveSession(activeSession)
|
||||
activeSession.stop()
|
||||
val newSession = SparkSession.builder().master("local").getOrCreate()
|
||||
assert(newSession != activeSession)
|
||||
newSession.stop()
|
||||
}
|
||||
}
|
|
@ -26,9 +26,9 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
|
|||
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
|
||||
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
|
||||
assert(df.queryExecution.analyzed.statistics.sizeInBytes >
|
||||
spark.wrapped.conf.autoBroadcastJoinThreshold)
|
||||
spark.sessionState.conf.autoBroadcastJoinThreshold)
|
||||
assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
|
||||
spark.wrapped.conf.autoBroadcastJoinThreshold)
|
||||
spark.sessionState.conf.autoBroadcastJoinThreshold)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -27,21 +27,21 @@ import org.apache.spark.sql.internal.SQLConf
|
|||
|
||||
class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
|
||||
|
||||
private var originalActiveSQLContext: Option[SQLContext] = _
|
||||
private var originalInstantiatedSQLContext: Option[SQLContext] = _
|
||||
private var originalActiveSQLContext: Option[SparkSession] = _
|
||||
private var originalInstantiatedSQLContext: Option[SparkSession] = _
|
||||
|
||||
override protected def beforeAll(): Unit = {
|
||||
originalActiveSQLContext = SQLContext.getActive()
|
||||
originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
|
||||
originalActiveSQLContext = SparkSession.getActiveSession
|
||||
originalInstantiatedSQLContext = SparkSession.getDefaultSession
|
||||
|
||||
SQLContext.clearActive()
|
||||
SQLContext.clearInstantiatedContext()
|
||||
SparkSession.clearActiveSession()
|
||||
SparkSession.clearDefaultSession()
|
||||
}
|
||||
|
||||
override protected def afterAll(): Unit = {
|
||||
// Set these states back.
|
||||
originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
|
||||
originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
|
||||
originalActiveSQLContext.foreach(ctx => SparkSession.setActiveSession(ctx))
|
||||
originalInstantiatedSQLContext.foreach(ctx => SparkSession.setDefaultSession(ctx))
|
||||
}
|
||||
|
||||
private def checkEstimation(
|
||||
|
|
|
@ -155,7 +155,7 @@ class PlannerSuite extends SharedSQLContext {
|
|||
val path = file.getCanonicalPath
|
||||
testData.write.parquet(path)
|
||||
val df = spark.read.parquet(path)
|
||||
spark.wrapped.registerDataFrameAsTable(df, "testPushed")
|
||||
spark.sqlContext.registerDataFrameAsTable(df, "testPushed")
|
||||
|
||||
withTempTable("testPushed") {
|
||||
val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
|
||||
|
|
|
@ -91,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
|
|||
expectedAnswer: Seq[Row],
|
||||
sortAnswers: Boolean = true): Unit = {
|
||||
SparkPlanTest
|
||||
.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match {
|
||||
.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.sqlContext) match {
|
||||
case Some(errorMessage) => fail(errorMessage)
|
||||
case None =>
|
||||
}
|
||||
|
@ -115,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
|
|||
expectedPlanFunction: SparkPlan => SparkPlan,
|
||||
sortAnswers: Boolean = true): Unit = {
|
||||
SparkPlanTest.checkAnswer(
|
||||
input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match {
|
||||
input, planFunction, expectedPlanFunction, sortAnswers, spark.sqlContext) match {
|
||||
case Some(errorMessage) => fail(errorMessage)
|
||||
case None =>
|
||||
}
|
||||
|
|
|
@ -90,7 +90,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
|
|||
(data: Seq[T], tableName: String, testVectorized: Boolean = true)
|
||||
(f: => Unit): Unit = {
|
||||
withParquetDataFrame(data, testVectorized) { df =>
|
||||
spark.wrapped.registerDataFrameAsTable(df, tableName)
|
||||
spark.sqlContext.registerDataFrameAsTable(df, tableName)
|
||||
withTempTable(tableName)(f)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
|
|||
val opId = 0
|
||||
val rdd1 =
|
||||
makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
|
||||
spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(
|
||||
spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
|
||||
increment)
|
||||
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
|
||||
|
||||
// Generate next version of stores
|
||||
val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
|
||||
spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
|
||||
spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
|
||||
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
|
||||
|
||||
// Make sure the previous RDD still has the same data.
|
||||
|
@ -82,7 +82,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
|
|||
spark: SparkSession,
|
||||
seq: Seq[String],
|
||||
storeVersion: Int): RDD[(String, Int)] = {
|
||||
implicit val sqlContext = spark.wrapped
|
||||
implicit val sqlContext = spark.sqlContext
|
||||
makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
|
||||
sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
|
|||
|
||||
test("usage with iterators - only gets and only puts") {
|
||||
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
|
||||
implicit val sqlContext = spark.wrapped
|
||||
implicit val sqlContext = spark.sqlContext
|
||||
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
|
||||
val opId = 0
|
||||
|
||||
|
@ -131,7 +131,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
|
|||
}
|
||||
|
||||
val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
|
||||
spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
|
||||
spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
|
||||
assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
|
||||
|
||||
val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
|
||||
|
@ -150,7 +150,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
|
|||
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
|
||||
|
||||
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
|
||||
implicit val sqlContext = spark.wrapped
|
||||
implicit val sqlContext = spark.sqlContext
|
||||
val coordinatorRef = sqlContext.streams.stateStoreCoordinator
|
||||
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
|
||||
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
|
||||
|
@ -183,7 +183,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
|
|||
SparkSession.builder
|
||||
.config(sparkConf.setMaster("local-cluster[2, 1, 1024]"))
|
||||
.getOrCreate()) { spark =>
|
||||
implicit val sqlContext = spark.wrapped
|
||||
implicit val sqlContext = spark.sqlContext
|
||||
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
|
||||
val opId = 0
|
||||
val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.mockito.Mockito.mock
|
|||
import org.apache.spark._
|
||||
import org.apache.spark.executor.TaskMetrics
|
||||
import org.apache.spark.scheduler._
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.util.quietly
|
||||
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
|
@ -400,8 +400,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
|
|||
.set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
|
||||
val sc = new SparkContext(conf)
|
||||
try {
|
||||
SQLContext.clearSqlListener()
|
||||
val spark = new SQLContext(sc)
|
||||
SparkSession.sqlListener.set(null)
|
||||
val spark = new SparkSession(sc)
|
||||
import spark.implicits._
|
||||
// Run 100 successful executions and 100 failed executions.
|
||||
// Each execution only has one job and one stage.
|
||||
|
|
|
@ -35,7 +35,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
|
|||
// Set a conf first.
|
||||
spark.conf.set(testKey, testVal)
|
||||
// Clear the conf.
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
// After clear, only overrideConfs used by unit test should be in the SQLConf.
|
||||
assert(spark.conf.getAll === TestSQLContext.overrideConfs)
|
||||
|
||||
|
@ -50,11 +50,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
|
|||
assert(spark.conf.get(testKey, testVal + "_") === testVal)
|
||||
assert(spark.conf.getAll.contains(testKey))
|
||||
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
}
|
||||
|
||||
test("parse SQL set commands") {
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
sql(s"set $testKey=$testVal")
|
||||
assert(spark.conf.get(testKey, testVal + "_") === testVal)
|
||||
assert(spark.conf.get(testKey, testVal + "_") === testVal)
|
||||
|
@ -72,11 +72,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
|
|||
sql(s"set $key=")
|
||||
assert(spark.conf.get(key, "0") === "")
|
||||
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
}
|
||||
|
||||
test("set command for display") {
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sessionState.conf.clear()
|
||||
checkAnswer(
|
||||
sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"),
|
||||
Nil)
|
||||
|
@ -97,7 +97,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("deprecated property") {
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
|
||||
try{
|
||||
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
|
||||
|
@ -108,7 +108,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("invalid conf value") {
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
val e = intercept[IllegalArgumentException] {
|
||||
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") {
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
|
||||
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100")
|
||||
assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100)
|
||||
|
@ -144,7 +144,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
|
|||
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
|
||||
}
|
||||
|
||||
spark.wrapped.conf.clear()
|
||||
spark.sqlContext.conf.clear()
|
||||
}
|
||||
|
||||
test("SparkSession can access configs set in SparkConf") {
|
||||
|
|
|
@ -41,7 +41,7 @@ case class SimpleDDLScan(
|
|||
table: String)(@transient val sparkSession: SparkSession)
|
||||
extends BaseRelation with TableScan {
|
||||
|
||||
override def sqlContext: SQLContext = sparkSession.wrapped
|
||||
override def sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
override def schema: StructType =
|
||||
StructType(Seq(
|
||||
|
|
|
@ -40,7 +40,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S
|
|||
extends BaseRelation
|
||||
with PrunedFilteredScan {
|
||||
|
||||
override def sqlContext: SQLContext = sparkSession.wrapped
|
||||
override def sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
override def schema: StructType =
|
||||
StructType(
|
||||
|
|
|
@ -37,7 +37,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: Spa
|
|||
extends BaseRelation
|
||||
with PrunedScan {
|
||||
|
||||
override def sqlContext: SQLContext = sparkSession.wrapped
|
||||
override def sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
override def schema: StructType =
|
||||
StructType(
|
||||
|
|
|
@ -38,7 +38,7 @@ class SimpleScanSource extends RelationProvider {
|
|||
case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
|
||||
extends BaseRelation with TableScan {
|
||||
|
||||
override def sqlContext: SQLContext = sparkSession.wrapped
|
||||
override def sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
override def schema: StructType =
|
||||
StructType(StructField("i", IntegerType, nullable = false) :: Nil)
|
||||
|
@ -70,7 +70,7 @@ case class AllDataTypesScan(
|
|||
extends BaseRelation
|
||||
with TableScan {
|
||||
|
||||
override def sqlContext: SQLContext = sparkSession.wrapped
|
||||
override def sqlContext: SQLContext = sparkSession.sqlContext
|
||||
|
||||
override def schema: StructType = userSpecifiedSchema
|
||||
|
||||
|
|
|
@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
|
|||
q.stop()
|
||||
|
||||
verify(LastOptions.mockStreamSourceProvider).createSource(
|
||||
spark.wrapped,
|
||||
spark.sqlContext,
|
||||
checkpointLocation + "/sources/0",
|
||||
None,
|
||||
"org.apache.spark.sql.streaming.test",
|
||||
Map.empty)
|
||||
|
||||
verify(LastOptions.mockStreamSourceProvider).createSource(
|
||||
spark.wrapped,
|
||||
spark.sqlContext,
|
||||
checkpointLocation + "/sources/1",
|
||||
None,
|
||||
"org.apache.spark.sql.streaming.test",
|
||||
|
|
|
@ -30,7 +30,7 @@ private[sql] trait SQLTestData { self =>
|
|||
|
||||
// Helper object to import SQL implicits without a concrete SQLContext
|
||||
private object internalImplicits extends SQLImplicits {
|
||||
protected override def _sqlContext: SQLContext = self.spark.wrapped
|
||||
protected override def _sqlContext: SQLContext = self.spark.sqlContext
|
||||
}
|
||||
|
||||
import internalImplicits._
|
||||
|
|
|
@ -66,7 +66,7 @@ private[sql] trait SQLTestUtils
|
|||
* but the implicits import is needed in the constructor.
|
||||
*/
|
||||
protected object testImplicits extends SQLImplicits {
|
||||
protected override def _sqlContext: SQLContext = self.spark.wrapped
|
||||
protected override def _sqlContext: SQLContext = self.spark.sqlContext
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -44,13 +44,13 @@ trait SharedSQLContext extends SQLTestUtils {
|
|||
/**
|
||||
* The [[TestSQLContext]] to use for all tests in this suite.
|
||||
*/
|
||||
protected implicit def sqlContext: SQLContext = _spark.wrapped
|
||||
protected implicit def sqlContext: SQLContext = _spark.sqlContext
|
||||
|
||||
/**
|
||||
* Initialize the [[TestSparkSession]].
|
||||
*/
|
||||
protected override def beforeAll(): Unit = {
|
||||
SQLContext.clearSqlListener()
|
||||
SparkSession.sqlListener.set(null)
|
||||
if (_spark == null) {
|
||||
_spark = new TestSparkSession(sparkConf)
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging {
|
|||
|
||||
val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate()
|
||||
sparkContext = sparkSession.sparkContext
|
||||
sqlContext = sparkSession.wrapped
|
||||
sqlContext = sparkSession.sqlContext
|
||||
|
||||
val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
|
||||
sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8"))
|
||||
|
|
|
@ -30,7 +30,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
|
|||
|
||||
override protected def beforeEach(): Unit = {
|
||||
super.beforeEach()
|
||||
if (spark.wrapped.tableNames().contains("src")) {
|
||||
if (spark.sqlContext.tableNames().contains("src")) {
|
||||
spark.catalog.dropTempView("src")
|
||||
}
|
||||
Seq((1, "")).toDF("key", "value").createOrReplaceTempView("src")
|
||||
|
|
|
@ -36,11 +36,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
withTempDatabase { db =>
|
||||
activateDatabase(db) {
|
||||
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
|
||||
assert(spark.wrapped.tableNames().contains("t"))
|
||||
assert(spark.sqlContext.tableNames().contains("t"))
|
||||
checkAnswer(spark.table("t"), df)
|
||||
}
|
||||
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
checkAnswer(spark.table(s"$db.t"), df)
|
||||
|
||||
checkTablePath(db, "t")
|
||||
|
@ -50,7 +50,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
test(s"saveAsTable() to non-default database - without USE - Overwrite") {
|
||||
withTempDatabase { db =>
|
||||
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
checkAnswer(spark.table(s"$db.t"), df)
|
||||
|
||||
checkTablePath(db, "t")
|
||||
|
@ -65,7 +65,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
|
||||
|
||||
spark.catalog.createExternalTable("t", path, "parquet")
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
checkAnswer(spark.table("t"), df)
|
||||
|
||||
sql(
|
||||
|
@ -76,7 +76,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
| path '$path'
|
||||
|)
|
||||
""".stripMargin)
|
||||
assert(spark.wrapped.tableNames(db).contains("t1"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t1"))
|
||||
checkAnswer(spark.table("t1"), df)
|
||||
}
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
|
||||
spark.catalog.createExternalTable(s"$db.t", path, "parquet")
|
||||
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
checkAnswer(spark.table(s"$db.t"), df)
|
||||
|
||||
sql(
|
||||
|
@ -101,7 +101,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
| path '$path'
|
||||
|)
|
||||
""".stripMargin)
|
||||
assert(spark.wrapped.tableNames(db).contains("t1"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t1"))
|
||||
checkAnswer(spark.table(s"$db.t1"), df)
|
||||
}
|
||||
}
|
||||
|
@ -112,11 +112,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
activateDatabase(db) {
|
||||
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
|
||||
df.write.mode(SaveMode.Append).saveAsTable("t")
|
||||
assert(spark.wrapped.tableNames().contains("t"))
|
||||
assert(spark.sqlContext.tableNames().contains("t"))
|
||||
checkAnswer(spark.table("t"), df.union(df))
|
||||
}
|
||||
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
checkAnswer(spark.table(s"$db.t"), df.union(df))
|
||||
|
||||
checkTablePath(db, "t")
|
||||
|
@ -127,7 +127,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
withTempDatabase { db =>
|
||||
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
|
||||
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
checkAnswer(spark.table(s"$db.t"), df.union(df))
|
||||
|
||||
checkTablePath(db, "t")
|
||||
|
@ -138,7 +138,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
withTempDatabase { db =>
|
||||
activateDatabase(db) {
|
||||
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
|
||||
assert(spark.wrapped.tableNames().contains("t"))
|
||||
assert(spark.sqlContext.tableNames().contains("t"))
|
||||
|
||||
df.write.insertInto(s"$db.t")
|
||||
checkAnswer(spark.table(s"$db.t"), df.union(df))
|
||||
|
@ -150,10 +150,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
withTempDatabase { db =>
|
||||
activateDatabase(db) {
|
||||
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
|
||||
assert(spark.wrapped.tableNames().contains("t"))
|
||||
assert(spark.sqlContext.tableNames().contains("t"))
|
||||
}
|
||||
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
|
||||
df.write.insertInto(s"$db.t")
|
||||
checkAnswer(spark.table(s"$db.t"), df.union(df))
|
||||
|
@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
|
|||
withTempDatabase { db =>
|
||||
activateDatabase(db) {
|
||||
sql(s"CREATE TABLE t (key INT)")
|
||||
assert(spark.wrapped.tableNames().contains("t"))
|
||||
assert(!spark.wrapped.tableNames("default").contains("t"))
|
||||
assert(spark.sqlContext.tableNames().contains("t"))
|
||||
assert(!spark.sqlContext.tableNames("default").contains("t"))
|
||||
}
|
||||
|
||||
assert(!spark.wrapped.tableNames().contains("t"))
|
||||
assert(spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(!spark.sqlContext.tableNames().contains("t"))
|
||||
assert(spark.sqlContext.tableNames(db).contains("t"))
|
||||
|
||||
activateDatabase(db) {
|
||||
sql(s"DROP TABLE t")
|
||||
assert(!spark.wrapped.tableNames().contains("t"))
|
||||
assert(!spark.wrapped.tableNames("default").contains("t"))
|
||||
assert(!spark.sqlContext.tableNames().contains("t"))
|
||||
assert(!spark.sqlContext.tableNames("default").contains("t"))
|
||||
}
|
||||
|
||||
assert(!spark.wrapped.tableNames().contains("t"))
|
||||
assert(!spark.wrapped.tableNames(db).contains("t"))
|
||||
assert(!spark.sqlContext.tableNames().contains("t"))
|
||||
assert(!spark.sqlContext.tableNames(db).contains("t"))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1417,7 +1417,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
|
|||
""".stripMargin)
|
||||
|
||||
checkAnswer(
|
||||
spark.wrapped.tables().select('isTemporary).filter('tableName === "t2"),
|
||||
spark.sqlContext.tables().select('isTemporary).filter('tableName === "t2"),
|
||||
Row(true)
|
||||
)
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
|
|||
(data: Seq[T], tableName: String)
|
||||
(f: => Unit): Unit = {
|
||||
withOrcDataFrame(data) { df =>
|
||||
spark.wrapped.registerDataFrameAsTable(df, tableName)
|
||||
spark.sqlContext.registerDataFrameAsTable(df, tableName)
|
||||
withTempTable(tableName)(f)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue