[SPARK-15075][SPARK-15345][SQL] Clean up SparkSession builder and propagate config options to existing sessions if specified

## What changes were proposed in this pull request?
Currently SparkSession.Builder uses SQLContext.getOrCreate. It should probably be the other way around, i.e. all the core logic goes in SparkSession, and SQLContext just calls that. This patch does that.

This patch also makes sure config options specified in the builder are propagated to the existing (and of course the new) SparkSession.
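
For illustration, a minimal sketch of the intended builder behavior (the key `spark.some.config` is a placeholder, not an option added by this patch):

```scala
import org.apache.spark.sql.SparkSession

// The first call creates the session (and the underlying SparkContext).
val spark = SparkSession.builder()
  .master("local")
  .appName("builder-demo")             // optional; a random name is generated if omitted
  .config("spark.some.config", "v1")   // placeholder key, for illustration only
  .getOrCreate()

// A later call returns the same session, and any options given to the
// builder are applied to that existing session as well.
val same = SparkSession.builder()
  .config("spark.some.config", "v2")
  .getOrCreate()

assert(same eq spark)
assert(spark.conf.get("spark.some.config") == "v2")
```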

## How was this patch tested?
Updated tests to reflect the change, and also introduced a new SparkSessionBuilderSuite that should cover all the branches.

Author: Reynold Xin <rxin@databricks.com>

Closes #13200 from rxin/SPARK-15075.
Reynold Xin 2016-05-19 21:53:26 -07:00
parent 17591d90e6
commit f2ee0ed4b7
43 changed files with 366 additions and 356 deletions

View file

@ -56,7 +56,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession {
} catch (IOException e) {
// expected
}
instance.write().context(spark.wrapped()).overwrite().save(outputPath);
instance.write().context(spark.sqlContext()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",

View file

@ -34,7 +34,10 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
class SQLContext(object):
"""Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.
"""The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class
here for backward compatibility.
A SQLContext can be used to create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.

View file

@ -120,6 +120,8 @@ class SparkSession(object):
def appName(self, name):
"""Sets a name for the application, which will be shown in the Spark web UI.
If no application name is set, a randomly generated name will be used.
:param name: an application name
"""
return self.config("spark.app.name", name)
@ -133,8 +135,17 @@ class SparkSession(object):
@since(2.0)
def getOrCreate(self):
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new
one based on the options set in this builder.
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
new one based on the options set in this builder.
This method first checks whether there is a valid thread-local SparkSession,
and if yes, return that one. It then checks whether there is a valid global
default SparkSession, and if yes, return that one. If no valid global default
SparkSession exists, the method creates a new SparkSession and assigns the
newly created SparkSession as the global default.
In case an existing SparkSession is returned, the config options specified
in this builder will be applied to the existing SparkSession.
"""
with self._lock:
from pyspark.conf import SparkConf
@ -175,7 +186,7 @@ class SparkSession(object):
if jsparkSession is None:
jsparkSession = self._jvm.SparkSession(self._jsc.sc())
self._jsparkSession = jsparkSession
self._jwrapped = self._jsparkSession.wrapped()
self._jwrapped = self._jsparkSession.sqlContext()
self._wrapped = SQLContext(self._sc, self, self._jwrapped)
_monkey_patch_RDD(self)
install_exception_handler()

View file

@ -213,7 +213,7 @@ class Dataset[T] private[sql](
private implicit def classTag = unresolvedTEncoder.clsTag
// sqlContext must be val because a stable identifier is expected when you import implicits
@transient lazy val sqlContext: SQLContext = sparkSession.wrapped
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
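
As a side note on the comment above: Scala only allows `import` from a stable identifier, which is why `Dataset.sqlContext` stays a `val`. A small REPL-style sketch (not code from this patch, assuming a local master) of what that enables:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").getOrCreate()
val df = spark.createDataFrame(Seq((1, "a"), (2, "b")))

// Legal only because Dataset.sqlContext is a val (a stable identifier);
// with a def this import would not compile.
import df.sqlContext.implicits._

val ds = Seq(1, 2, 3).toDS()
```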

View file

@ -19,25 +19,22 @@ package org.apache.spark.sql
import java.beans.BeanInfo
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
@ -46,8 +43,8 @@ import org.apache.spark.sql.util.ExecutionListenerManager
/**
* The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
*
* As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class here
* for backward compatibility.
* As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class
* here for backward compatibility.
*
* @groupname basic Basic Operations
* @groupname ddl_ops Persistent Catalog DDL
@ -76,42 +73,21 @@ class SQLContext private[sql](
this(sparkSession, true)
}
@deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sc: SparkContext) = {
this(new SparkSession(sc))
}
@deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
// TODO: move this logic into SparkSession
// If spark.sql.allowMultipleContexts is true, we will throw an exception if a user
// wants to create a new root SQLContext (a SQLContext that is not created by newSession).
private val allowMultipleContexts =
sparkContext.conf.getBoolean(
SQLConf.ALLOW_MULTIPLE_CONTEXTS.key,
SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get)
// Assert no root SQLContext is running when allowMultipleContexts is false.
{
if (!allowMultipleContexts && isRootContext) {
SQLContext.getInstantiatedContextOption() match {
case Some(rootSQLContext) =>
val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " +
s"It is recommended to use SQLContext.getOrCreate to get the instantiated " +
s"SQLContext/HiveContext. To ignore this error, " +
s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf."
throw new SparkException(errMsg)
case None => // OK
}
}
}
protected[sql] def sessionState: SessionState = sparkSession.sessionState
protected[sql] def sharedState: SharedState = sparkSession.sharedState
protected[sql] def conf: SQLConf = sessionState.conf
protected[sql] def runtimeConf: RuntimeConfig = sparkSession.conf
protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager
protected[sql] def listener: SQLListener = sparkSession.listener
protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog
def sparkContext: SparkContext = sparkSession.sparkContext
@ -123,7 +99,7 @@ class SQLContext private[sql](
*
* @since 1.6.0
*/
def newSession(): SQLContext = sparkSession.newSession().wrapped
def newSession(): SQLContext = sparkSession.newSession().sqlContext
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@ -760,21 +736,6 @@ class SQLContext private[sql](
schema: StructType): DataFrame = {
sparkSession.applySchemaToPythonRDD(rdd, schema)
}
// TODO: move this logic into SparkSession
// Register a successfully instantiated context to the singleton. This should be at the end of
// the class definition so that the singleton is updated only if there is no exception in the
// construction of the instance.
sparkContext.addSparkListener(new SparkListener {
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
SQLContext.clearInstantiatedContext()
SQLContext.clearSqlListener()
}
})
sparkSession.setWrappedContext(self)
SQLContext.setInstantiatedContext(self)
}
/**
@ -787,19 +748,6 @@ class SQLContext private[sql](
*/
object SQLContext {
/**
* The active SQLContext for the current thread.
*/
private val activeContext: InheritableThreadLocal[SQLContext] =
new InheritableThreadLocal[SQLContext]
/**
* Reference to the created SQLContext.
*/
@transient private val instantiatedContext = new AtomicReference[SQLContext]()
@transient private val sqlListener = new AtomicReference[SQLListener]()
/**
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
*
@ -811,41 +759,9 @@ object SQLContext {
*
* @since 1.5.0
*/
@deprecated("Use SparkSession.builder instead", "2.0.0")
def getOrCreate(sparkContext: SparkContext): SQLContext = {
val ctx = activeContext.get()
if (ctx != null && !ctx.sparkContext.isStopped) {
return ctx
}
synchronized {
val ctx = instantiatedContext.get()
if (ctx == null || ctx.sparkContext.isStopped) {
new SQLContext(sparkContext)
} else {
ctx
}
}
}
private[sql] def clearInstantiatedContext(): Unit = {
instantiatedContext.set(null)
}
private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
synchronized {
val ctx = instantiatedContext.get()
if (ctx == null || ctx.sparkContext.isStopped) {
instantiatedContext.set(sqlContext)
}
}
}
private[sql] def getInstantiatedContextOption(): Option[SQLContext] = {
Option(instantiatedContext.get())
}
private[sql] def clearSqlListener(): Unit = {
sqlListener.set(null)
SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext
}
/**
@ -855,8 +771,9 @@ object SQLContext {
*
* @since 1.6.0
*/
@deprecated("Use SparkSession.setActiveSession instead", "2.0.0")
def setActive(sqlContext: SQLContext): Unit = {
activeContext.set(sqlContext)
SparkSession.setActiveSession(sqlContext.sparkSession)
}
/**
@ -865,12 +782,9 @@ object SQLContext {
*
* @since 1.6.0
*/
@deprecated("Use SparkSession.clearActiveSession instead", "2.0.0")
def clearActive(): Unit = {
activeContext.remove()
}
private[sql] def getActive(): Option[SQLContext] = {
Option(activeContext.get())
SparkSession.clearActiveSession()
}
/**
@ -894,20 +808,6 @@ object SQLContext {
}
}
/**
* Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
*/
private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = {
if (sqlListener.get() == null) {
val listener = new SQLListener(sc.conf)
if (sqlListener.compareAndSet(null, listener)) {
sc.addSparkListener(listener)
sc.ui.foreach(new SQLTab(listener, _))
}
}
sqlListener.get()
}
/**
* Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
*/
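
Taken together, the changes in this file reduce the deprecated SQLContext entry points to thin wrappers around SparkSession. A hedged sketch of how that looks from the caller's side (assuming a fresh JVM; the assertion reflects the intended behavior rather than a line from this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SparkSession, SQLContext}

val sc = new SparkContext(
  new SparkConf().setMaster("local").setAppName("compat-demo"))

// Deprecated 1.x entry point; now implemented as
// SparkSession.builder().sparkContext(sc).getOrCreate().sqlContext
val sqlContext = SQLContext.getOrCreate(sc)

// The builder returns the SparkSession that backs that SQLContext,
// so both share the same SparkContext.
val spark = SparkSession.builder().getOrCreate()
assert(sqlContext.sparkContext eq spark.sparkContext)

// setActive/clearActive now simply forward to their SparkSession counterparts.
SQLContext.setActive(sqlContext)   // SparkSession.setActiveSession(...)
SQLContext.clearActive()           // SparkSession.clearActiveSession()
```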

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.beans.Introspector
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@ -30,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@ -98,24 +100,10 @@ class SparkSession private(
}
/**
* A wrapped version of this session in the form of a [[SQLContext]].
* A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility.
*/
@transient
private var _wrapped: SQLContext = _
@transient
private val _wrappedLock = new Object
protected[sql] def wrapped: SQLContext = _wrappedLock.synchronized {
if (_wrapped == null) {
_wrapped = new SQLContext(self, isRootContext = false)
}
_wrapped
}
protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = _wrappedLock.synchronized {
_wrapped = sqlContext
}
private[sql] val sqlContext: SQLContext = new SQLContext(this)
protected[sql] def cacheManager: CacheManager = sharedState.cacheManager
protected[sql] def listener: SQLListener = sharedState.listener
@ -238,7 +226,7 @@ class SparkSession private(
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
SQLContext.setActive(wrapped)
SparkSession.setActiveSession(this)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
@ -254,7 +242,7 @@ class SparkSession private(
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
SQLContext.setActive(wrapped)
SparkSession.setActiveSession(this)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
@ -573,7 +561,7 @@ class SparkSession private(
*/
@Experimental
object implicits extends SQLImplicits with Serializable {
protected override def _sqlContext: SQLContext = wrapped
protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext
}
// scalastyle:on
@ -649,8 +637,16 @@ object SparkSession {
private[this] val options = new scala.collection.mutable.HashMap[String, String]
private[this] var userSuppliedContext: Option[SparkContext] = None
private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
userSuppliedContext = Option(sparkContext)
this
}
/**
* Sets a name for the application, which will be shown in the Spark web UI.
* If no application name is set, a randomly generated name will be used.
*
* @since 2.0.0
*/
@ -735,29 +731,130 @@ object SparkSession {
}
/**
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new one
* based on the options set in this builder.
* Gets an existing [[SparkSession]] or, if there is no existing one, creates a new
* one based on the options set in this builder.
*
* This method first checks whether there is a valid thread-local SparkSession,
* and if yes, return that one. It then checks whether there is a valid global
* default SparkSession, and if yes, return that one. If no valid global default
* SparkSession exists, the method creates a new SparkSession and assigns the
* newly created SparkSession as the global default.
*
* In case an existing SparkSession is returned, the config options specified in
* this builder will be applied to the existing SparkSession.
*
* @since 2.0.0
*/
def getOrCreate(): SparkSession = synchronized {
// Step 1. Create a SparkConf
// Step 2. Get a SparkContext
// Step 3. Get a SparkSession
val sparkConf = new SparkConf()
options.foreach { case (k, v) => sparkConf.set(k, v) }
val sparkContext = SparkContext.getOrCreate(sparkConf)
// Get the session from current thread's active session.
var session = activeThreadSession.get()
if ((session ne null) && !session.sparkContext.isStopped) {
options.foreach { case (k, v) => session.conf.set(k, v) }
return session
}
SQLContext.getOrCreate(sparkContext).sparkSession
// Global synchronization so we will only set the default session once.
SparkSession.synchronized {
// If the current thread does not have an active session, get it from the global session.
session = defaultSession.get()
if ((session ne null) && !session.sparkContext.isStopped) {
options.foreach { case (k, v) => session.conf.set(k, v) }
return session
}
// No active nor global default session. Create a new one.
val sparkContext = userSuppliedContext.getOrElse {
// set app name if not given
if (!options.contains("spark.app.name")) {
options += "spark.app.name" -> java.util.UUID.randomUUID().toString
}
val sparkConf = new SparkConf()
options.foreach { case (k, v) => sparkConf.set(k, v) }
SparkContext.getOrCreate(sparkConf)
}
session = new SparkSession(sparkContext)
options.foreach { case (k, v) => session.conf.set(k, v) }
defaultSession.set(session)
// Register a successfully instantiated context to the singleton. This should be at the
// end of the class definition so that the singleton is updated only if there is no
// exception in the construction of the instance.
sparkContext.addSparkListener(new SparkListener {
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
defaultSession.set(null)
sqlListener.set(null)
}
})
}
return session
}
}
/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
* @since 2.0.0
*/
def builder(): Builder = new Builder
/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* a SparkSession with an isolated session, instead of the global (first created) context.
*
* @since 2.0.0
*/
def setActiveSession(session: SparkSession): Unit = {
activeThreadSession.set(session)
}
/**
* Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will
* return the first created context instead of a thread-local override.
*
* @since 2.0.0
*/
def clearActiveSession(): Unit = {
activeThreadSession.remove()
}
/**
* Sets the default SparkSession that is returned by the builder.
*
* @since 2.0.0
*/
def setDefaultSession(session: SparkSession): Unit = {
defaultSession.set(session)
}
/**
* Clears the default SparkSession that is returned by the builder.
*
* @since 2.0.0
*/
def clearDefaultSession(): Unit = {
defaultSession.set(null)
}
private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
/** A global SQL listener used for the SQL UI. */
private[sql] val sqlListener = new AtomicReference[SQLListener]()
////////////////////////////////////////////////////////////////////////////////////////
// Private methods from now on
////////////////////////////////////////////////////////////////////////////////////////
/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[SparkSession]
/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]
private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState"
private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
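
The new session-management API above is easiest to see end to end. A short sketch mirroring the behavior described in the `getOrCreate` scaladoc (assuming a fresh JVM with no prior sessions; `some.key` is a placeholder config key):

```scala
import org.apache.spark.sql.SparkSession

// 1. With no active or default session, getOrCreate creates one and
//    registers it as the global default.
val default = SparkSession.builder().master("local").getOrCreate()

// 2. A thread-local active session takes precedence over the global default.
val isolated = default.newSession()
SparkSession.setActiveSession(isolated)
assert(SparkSession.builder().getOrCreate() eq isolated)

// 3. Options passed to the builder are applied to whichever existing
//    session gets returned.
SparkSession.builder().config("some.key", "some.value").getOrCreate()
assert(isolated.conf.get("some.key") == "some.value")

// 4. Clearing the thread-local override falls back to the global default.
SparkSession.clearActiveSession()
assert(SparkSession.builder().getOrCreate() eq default)
```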

View file

@ -157,7 +157,8 @@ private[sql] case class RowDataSourceScanExec(
val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
!SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
!SparkSession.getActiveSession.get.sessionState.conf.getConf(
SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
case _: HadoopFsRelation => true
case _ => false
}

View file

@ -60,7 +60,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
}
lazy val analyzed: LogicalPlan = {
SQLContext.setActive(sparkSession.wrapped)
SparkSession.setActiveSession(sparkSession)
sparkSession.sessionState.analyzer.execute(logical)
}
@ -73,7 +73,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)
lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sparkSession.wrapped)
SparkSession.setActiveSession(sparkSession)
planner.plan(ReturnAnswer(optimizedPlan)).next()
}

View file

@ -27,7 +27,7 @@ import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.{Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@ -50,7 +50,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* populated by the query planning infrastructure.
*/
@transient
protected[spark] final val sqlContext = SQLContext.getActive().orNull
final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull
protected def sparkContext = sqlContext.sparkContext
@ -65,7 +65,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
SQLContext.setActive(sqlContext)
SparkSession.setActiveSession(sqlContext.sparkSession)
super.makeCopy(newArgs)
}

View file

@ -178,7 +178,7 @@ case class DataSource(
providingClass.newInstance() match {
case s: StreamSourceProvider =>
val (name, schema) = s.sourceSchema(
sparkSession.wrapped, userSpecifiedSchema, className, options)
sparkSession.sqlContext, userSpecifiedSchema, className, options)
SourceInfo(name, schema)
case format: FileFormat =>
@ -198,7 +198,8 @@ case class DataSource(
def createSource(metadataPath: String): Source = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options)
s.createSource(
sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options)
case format: FileFormat =>
val path = new CaseInsensitiveMap(options).getOrElse("path", {
@ -215,7 +216,7 @@ case class DataSource(
/** Returns a sink that can be used to continually write data. */
def createSink(): Sink = {
providingClass.newInstance() match {
case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns)
case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns)
case parquet: parquet.DefaultSource =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@ -265,9 +266,9 @@ case class DataSource(
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema)
dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
case (dataSource: RelationProvider, None) =>
dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions)
dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
case (_: SchemaRelationProvider, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $className.")
case (_: RelationProvider, Some(_)) =>
@ -383,7 +384,7 @@ case class DataSource(
providingClass.newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(sparkSession.wrapped, mode, options, data)
dataSource.createRelation(sparkSession.sqlContext, mode, options, data)
case format: FileFormat =>
// Don't glob path for the write path. The contracts here are:
// 1. Only one output path can be specified on the write path;

View file

@ -142,7 +142,7 @@ case class HadoopFsRelation(
fileFormat: FileFormat,
options: Map[String, String]) extends BaseRelation with FileRelation {
override def sqlContext: SQLContext = sparkSession.wrapped
override def sqlContext: SQLContext = sparkSession.sqlContext
val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet

View file

@ -92,7 +92,7 @@ private[sql] case class JDBCRelation(
with PrunedFilteredScan
with InsertableRelation {
override def sqlContext: SQLContext = sparkSession.wrapped
override def sqlContext: SQLContext = sparkSession.sqlContext
override val needConversion: Boolean = false

View file

@ -173,7 +173,7 @@ class StreamExecution(
startLatch.countDown()
// While active, repeatedly attempt to run batches.
SQLContext.setActive(sparkSession.wrapped)
SparkSession.setActiveSession(sparkSession)
triggerExecutor.execute(() => {
if (isActive) {

View file

@ -1168,7 +1168,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse {
val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
new SparkSqlParser(new SQLConf)
}
Column(parser.parseExpression(expr))

View file

@ -70,16 +70,6 @@ object SQLConf {
.intConf
.createWithDefault(10)
val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts")
.doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " +
"When set to false, only one SQLContext/HiveContext is allowed to be created " +
"through the constructor (new SQLContexts/HiveContexts created through newSession " +
"method is allowed). Please note that this conf needs to be set in Spark Conf. Once " +
"a SQLContext/HiveContext has been created, changing the value of this conf will not " +
"have effect.")
.booleanConf
.createWithDefault(true)
val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed")
.internal()
.doc("When set to true Spark SQL will automatically select a compression codec for each " +

View file

@ -18,10 +18,10 @@
package org.apache.spark.sql.internal
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
import org.apache.spark.sql.execution.CacheManager
import org.apache.spark.sql.execution.ui.SQLListener
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.util.MutableURLClassLoader
@ -38,7 +38,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
/**
* A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s.
*/
val listener: SQLListener = SQLContext.createListenerAndUI(sparkContext)
val listener: SQLListener = createListenerAndUI(sparkContext)
/**
* A catalog that interacts with external systems.
@ -51,6 +51,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) {
val jarClassLoader = new NonClosableMutableURLClassLoader(
org.apache.spark.util.Utils.getContextOrSparkClassLoader)
/**
* Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
*/
private def createListenerAndUI(sc: SparkContext): SQLListener = {
if (SparkSession.sqlListener.get() == null) {
val listener = new SQLListener(sc.conf)
if (SparkSession.sqlListener.compareAndSet(null, listener)) {
sc.addSparkListener(listener)
sc.ui.foreach(new SQLTab(listener, _))
}
}
SparkSession.sqlListener.get()
}
}

View file

@ -39,7 +39,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
test("get all tables") {
checkAnswer(
spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"),
spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
@ -48,12 +48,12 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("getting all tables with a database name has no impact on returned table names") {
checkAnswer(
spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"),
spark.sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"),
Row("listtablessuitetable", true))
checkAnswer(
@ -62,7 +62,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
spark.sessionState.catalog.dropTable(
TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true)
assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0)
}
test("query the returned DataFrame of tables") {
@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach {
Seq(spark.sqlContext.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
@ -81,7 +81,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex
Row(true, "listtablessuitetable")
)
checkAnswer(
spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
spark.sqlContext.tables()
.filter("tableName = 'tables'").select("tableName", "isTemporary"),
Row("tables", true))
spark.catalog.dropTempView("tables")
}

View file

@ -1,100 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.scalatest.BeforeAndAfterAll
import org.apache.spark._
import org.apache.spark.sql.internal.SQLConf
class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
private var originalActiveSQLContext: Option[SQLContext] = _
private var originalInstantiatedSQLContext: Option[SQLContext] = _
private var sparkConf: SparkConf = _
override protected def beforeAll(): Unit = {
originalActiveSQLContext = SQLContext.getActive()
originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
SQLContext.clearActive()
SQLContext.clearInstantiatedContext()
sparkConf =
new SparkConf(false)
.setMaster("local[*]")
.setAppName("test")
.set("spark.ui.enabled", "false")
.set("spark.driver.allowMultipleContexts", "true")
}
override protected def afterAll(): Unit = {
// Set these states back.
originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
}
def testNewSession(rootSQLContext: SQLContext): Unit = {
// Make sure we can successfully create new Session.
rootSQLContext.newSession()
// Reset the state. It is always safe to clear the active context.
SQLContext.clearActive()
}
def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = {
val conf =
sparkConf
.clone
.set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString)
val sparkContext = new SparkContext(conf)
try {
if (allowsMultipleContexts) {
new SQLContext(sparkContext)
SQLContext.clearActive()
} else {
// If allowsMultipleContexts is false, make sure we can get the error.
val message = intercept[SparkException] {
new SQLContext(sparkContext)
}.getMessage
assert(message.contains("Only one SQLContext/HiveContext may be running"))
}
} finally {
sparkContext.stop()
}
}
test("test the flag to disallow creating multiple root SQLContext") {
Seq(false, true).foreach { allowMultipleSQLContexts =>
val conf =
sparkConf
.clone
.set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString)
val sc = new SparkContext(conf)
try {
val rootSQLContext = new SQLContext(sc)
testNewSession(rootSQLContext)
testNewSession(rootSQLContext)
testCreatingNewSQLContext(allowMultipleSQLContexts)
} finally {
sc.stop()
SQLContext.clearInstantiatedContext()
}
}
}
}

View file

@ -40,7 +40,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
val newSession = sqlContext.newSession()
assert(SQLContext.getOrCreate(sc).eq(sqlContext),
"SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
SQLContext.setActive(newSession)
SparkSession.setActiveSession(newSession.sparkSession)
assert(SQLContext.getOrCreate(sc).eq(newSession),
"SQLContext.getOrCreate after explicitly setActive() did not return the active context")
}

View file

@ -1042,7 +1042,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SET commands semantics using sql()") {
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
val testKey = "test.key.0"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
@ -1083,17 +1083,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql(s"SET $nonexistentKey"),
Row(nonexistentKey, "<undefined>")
)
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
}
test("SET commands with illegal or inappropriate argument") {
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
// Set negative mapred.reduce.tasks for automatically determining
// the number of reducers is not supported
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01"))
intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2"))
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
}
test("apply schema") {

View file

@ -25,6 +25,6 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
val spark = SparkSession.builder.getOrCreate()
new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped)
new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext)
}
}

View file

@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.{SparkContext, SparkFunSuite}
/**
* Test cases for the builder pattern of [[SparkSession]].
*/
class SparkSessionBuilderSuite extends SparkFunSuite {
private var initialSession: SparkSession = _
private lazy val sparkContext: SparkContext = {
initialSession = SparkSession.builder()
.master("local")
.config("spark.ui.enabled", value = false)
.config("some-config", "v2")
.getOrCreate()
initialSession.sparkContext
}
test("create with config options and propagate them to SparkContext and SparkSession") {
// Creating a new session with config - this works by just calling the lazy val
sparkContext
assert(initialSession.sparkContext.conf.get("some-config") == "v2")
assert(initialSession.conf.get("some-config") == "v2")
SparkSession.clearDefaultSession()
}
test("use global default session") {
val session = SparkSession.builder().getOrCreate()
assert(SparkSession.builder().getOrCreate() == session)
SparkSession.clearDefaultSession()
}
test("config options are propagated to existing SparkSession") {
val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate()
assert(session1.conf.get("spark-config1") == "a")
val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate()
assert(session1 == session2)
assert(session1.conf.get("spark-config1") == "b")
SparkSession.clearDefaultSession()
}
test("use session from active thread session and propagate config options") {
val defaultSession = SparkSession.builder().getOrCreate()
val activeSession = defaultSession.newSession()
SparkSession.setActiveSession(activeSession)
val session = SparkSession.builder().config("spark-config2", "a").getOrCreate()
assert(activeSession != defaultSession)
assert(session == activeSession)
assert(session.conf.get("spark-config2") == "a")
SparkSession.clearActiveSession()
assert(SparkSession.builder().getOrCreate() == defaultSession)
SparkSession.clearDefaultSession()
}
test("create a new session if the default session has been stopped") {
val defaultSession = SparkSession.builder().getOrCreate()
SparkSession.setDefaultSession(defaultSession)
defaultSession.stop()
val newSession = SparkSession.builder().master("local").getOrCreate()
assert(newSession != defaultSession)
newSession.stop()
}
test("create a new session if the active thread session has been stopped") {
val activeSession = SparkSession.builder().master("local").getOrCreate()
SparkSession.setActiveSession(activeSession)
activeSession.stop()
val newSession = SparkSession.builder().master("local").getOrCreate()
assert(newSession != activeSession)
newSession.stop()
}
}

View file

@ -26,9 +26,9 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
assert(df.queryExecution.analyzed.statistics.sizeInBytes >
spark.wrapped.conf.autoBroadcastJoinThreshold)
spark.sessionState.conf.autoBroadcastJoinThreshold)
assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes >
spark.wrapped.conf.autoBroadcastJoinThreshold)
spark.sessionState.conf.autoBroadcastJoinThreshold)
}
}

View file

@ -27,21 +27,21 @@ import org.apache.spark.sql.internal.SQLConf
class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
private var originalActiveSQLContext: Option[SQLContext] = _
private var originalInstantiatedSQLContext: Option[SQLContext] = _
private var originalActiveSQLContext: Option[SparkSession] = _
private var originalInstantiatedSQLContext: Option[SparkSession] = _
override protected def beforeAll(): Unit = {
originalActiveSQLContext = SQLContext.getActive()
originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
originalActiveSQLContext = SparkSession.getActiveSession
originalInstantiatedSQLContext = SparkSession.getDefaultSession
SQLContext.clearActive()
SQLContext.clearInstantiatedContext()
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
}
override protected def afterAll(): Unit = {
// Set these states back.
originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx))
originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx))
originalActiveSQLContext.foreach(ctx => SparkSession.setActiveSession(ctx))
originalInstantiatedSQLContext.foreach(ctx => SparkSession.setDefaultSession(ctx))
}
private def checkEstimation(

View file

@ -155,7 +155,7 @@ class PlannerSuite extends SharedSQLContext {
val path = file.getCanonicalPath
testData.write.parquet(path)
val df = spark.read.parquet(path)
spark.wrapped.registerDataFrameAsTable(df, "testPushed")
spark.sqlContext.registerDataFrameAsTable(df, "testPushed")
withTempTable("testPushed") {
val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan

View file

@ -91,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
SparkPlanTest
.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match {
.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@ -115,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match {
input, planFunction, expectedPlanFunction, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}

View file

@ -90,7 +90,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T], tableName: String, testVectorized: Boolean = true)
(f: => Unit): Unit = {
withParquetDataFrame(data, testVectorized) { df =>
spark.wrapped.registerDataFrameAsTable(df, tableName)
spark.sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}

View file

@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val opId = 0
val rdd1 =
makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(
spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
@ -82,7 +82,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
spark: SparkSession,
seq: Seq[String],
storeVersion: Int): RDD[(String, Int)] = {
implicit val sqlContext = spark.wrapped
implicit val sqlContext = spark.sqlContext
makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
}
@ -102,7 +102,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
test("usage with iterators - only gets and only puts") {
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
implicit val sqlContext = spark.wrapped
implicit val sqlContext = spark.sqlContext
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
@ -131,7 +131,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
@ -150,7 +150,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
implicit val sqlContext = spark.wrapped
implicit val sqlContext = spark.sqlContext
val coordinatorRef = sqlContext.streams.stateStoreCoordinator
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
@ -183,7 +183,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
SparkSession.builder
.config(sparkConf.setMaster("local-cluster[2, 1, 1024]"))
.getOrCreate()) { spark =>
implicit val sqlContext = spark.wrapped
implicit val sqlContext = spark.sqlContext
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
val opId = 0
val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(

View file

@ -24,7 +24,7 @@ import org.mockito.Mockito.mock
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
@ -400,8 +400,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
.set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
val sc = new SparkContext(conf)
try {
SQLContext.clearSqlListener()
val spark = new SQLContext(sc)
SparkSession.sqlListener.set(null)
val spark = new SparkSession(sc)
import spark.implicits._
// Run 100 successful executions and 100 failed executions.
// Each execution only has one job and one stage.

View file

@ -35,7 +35,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
// Set a conf first.
spark.conf.set(testKey, testVal)
// Clear the conf.
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
// After clear, only overrideConfs used by unit test should be in the SQLConf.
assert(spark.conf.getAll === TestSQLContext.overrideConfs)
@ -50,11 +50,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
assert(spark.conf.get(testKey, testVal + "_") === testVal)
assert(spark.conf.getAll.contains(testKey))
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
}
test("parse SQL set commands") {
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
sql(s"set $testKey=$testVal")
assert(spark.conf.get(testKey, testVal + "_") === testVal)
assert(spark.conf.get(testKey, testVal + "_") === testVal)
@ -72,11 +72,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
sql(s"set $key=")
assert(spark.conf.get(key, "0") === "")
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
}
test("set command for display") {
spark.wrapped.conf.clear()
spark.sessionState.conf.clear()
checkAnswer(
sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"),
Nil)
@ -97,7 +97,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("deprecated property") {
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
try{
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
@ -108,7 +108,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("invalid conf value") {
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
val e = intercept[IllegalArgumentException] {
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
@ -116,7 +116,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") {
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100")
assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100)
@ -144,7 +144,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
}
spark.wrapped.conf.clear()
spark.sqlContext.conf.clear()
}
test("SparkSession can access configs set in SparkConf") {

View file

@ -41,7 +41,7 @@ case class SimpleDDLScan(
table: String)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
override def sqlContext: SQLContext = sparkSession.wrapped
override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(Seq(

View file

@ -40,7 +40,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S
extends BaseRelation
with PrunedFilteredScan {
override def sqlContext: SQLContext = sparkSession.wrapped
override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(

View file

@ -37,7 +37,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: Spa
extends BaseRelation
with PrunedScan {
override def sqlContext: SQLContext = sparkSession.wrapped
override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(

View file

@ -38,7 +38,7 @@ class SimpleScanSource extends RelationProvider {
case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
extends BaseRelation with TableScan {
override def sqlContext: SQLContext = sparkSession.wrapped
override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType =
StructType(StructField("i", IntegerType, nullable = false) :: Nil)
@ -70,7 +70,7 @@ case class AllDataTypesScan(
extends BaseRelation
with TableScan {
override def sqlContext: SQLContext = sparkSession.wrapped
override def sqlContext: SQLContext = sparkSession.sqlContext
override def schema: StructType = userSpecifiedSchema

View file

@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
q.stop()
verify(LastOptions.mockStreamSourceProvider).createSource(
spark.wrapped,
spark.sqlContext,
checkpointLocation + "/sources/0",
None,
"org.apache.spark.sql.streaming.test",
Map.empty)
verify(LastOptions.mockStreamSourceProvider).createSource(
spark.wrapped,
spark.sqlContext,
checkpointLocation + "/sources/1",
None,
"org.apache.spark.sql.streaming.test",

View file

@ -30,7 +30,7 @@ private[sql] trait SQLTestData { self =>
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self.spark.wrapped
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
import internalImplicits._

View file

@ -66,7 +66,7 @@ private[sql] trait SQLTestUtils
* but the implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self.spark.wrapped
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
/**

View file

@ -44,13 +44,13 @@ trait SharedSQLContext extends SQLTestUtils {
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
protected implicit def sqlContext: SQLContext = _spark.wrapped
protected implicit def sqlContext: SQLContext = _spark.sqlContext
/**
* Initialize the [[TestSparkSession]].
*/
protected override def beforeAll(): Unit = {
SQLContext.clearSqlListener()
SparkSession.sqlListener.set(null)
if (_spark == null) {
_spark = new TestSparkSession(sparkConf)
}

View file

@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging {
val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate()
sparkContext = sparkSession.sparkContext
sqlContext = sparkSession.wrapped
sqlContext = sparkSession.sqlContext
val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8"))

View file

@ -30,7 +30,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
override protected def beforeEach(): Unit = {
super.beforeEach()
if (spark.wrapped.tableNames().contains("src")) {
if (spark.sqlContext.tableNames().contains("src")) {
spark.catalog.dropTempView("src")
}
Seq((1, "")).toDF("key", "value").createOrReplaceTempView("src")

View file

@ -36,11 +36,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
assert(spark.wrapped.tableNames().contains("t"))
assert(spark.sqlContext.tableNames().contains("t"))
checkAnswer(spark.table("t"), df)
}
assert(spark.wrapped.tableNames(db).contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
@ -50,7 +50,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
test(s"saveAsTable() to non-default database - without USE - Overwrite") {
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
assert(spark.wrapped.tableNames(db).contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
checkTablePath(db, "t")
@ -65,7 +65,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
spark.catalog.createExternalTable("t", path, "parquet")
assert(spark.wrapped.tableNames(db).contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table("t"), df)
sql(
@ -76,7 +76,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
assert(spark.wrapped.tableNames(db).contains("t1"))
assert(spark.sqlContext.tableNames(db).contains("t1"))
checkAnswer(spark.table("t1"), df)
}
}
@ -90,7 +90,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
spark.catalog.createExternalTable(s"$db.t", path, "parquet")
assert(spark.wrapped.tableNames(db).contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df)
sql(
@ -101,7 +101,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
| path '$path'
|)
""".stripMargin)
assert(spark.wrapped.tableNames(db).contains("t1"))
assert(spark.sqlContext.tableNames(db).contains("t1"))
checkAnswer(spark.table(s"$db.t1"), df)
}
}
@ -112,11 +112,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
df.write.mode(SaveMode.Append).saveAsTable("t")
assert(spark.wrapped.tableNames().contains("t"))
assert(spark.sqlContext.tableNames().contains("t"))
checkAnswer(spark.table("t"), df.union(df))
}
assert(spark.wrapped.tableNames(db).contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
@ -127,7 +127,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
assert(spark.wrapped.tableNames(db).contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
checkAnswer(spark.table(s"$db.t"), df.union(df))
checkTablePath(db, "t")
@ -138,7 +138,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
assert(spark.wrapped.tableNames().contains("t"))
assert(spark.sqlContext.tableNames().contains("t"))
df.write.insertInto(s"$db.t")
checkAnswer(spark.table(s"$db.t"), df.union(df))
@ -150,10 +150,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
assert(spark.wrapped.tableNames().contains("t"))
assert(spark.sqlContext.tableNames().contains("t"))
}
assert(spark.wrapped.tableNames(db).contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
df.write.insertInto(s"$db.t")
checkAnswer(spark.table(s"$db.t"), df.union(df))
@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
withTempDatabase { db =>
activateDatabase(db) {
sql(s"CREATE TABLE t (key INT)")
assert(spark.wrapped.tableNames().contains("t"))
assert(!spark.wrapped.tableNames("default").contains("t"))
assert(spark.sqlContext.tableNames().contains("t"))
assert(!spark.sqlContext.tableNames("default").contains("t"))
}
assert(!spark.wrapped.tableNames().contains("t"))
assert(spark.wrapped.tableNames(db).contains("t"))
assert(!spark.sqlContext.tableNames().contains("t"))
assert(spark.sqlContext.tableNames(db).contains("t"))
activateDatabase(db) {
sql(s"DROP TABLE t")
assert(!spark.wrapped.tableNames().contains("t"))
assert(!spark.wrapped.tableNames("default").contains("t"))
assert(!spark.sqlContext.tableNames().contains("t"))
assert(!spark.sqlContext.tableNames("default").contains("t"))
}
assert(!spark.wrapped.tableNames().contains("t"))
assert(!spark.wrapped.tableNames(db).contains("t"))
assert(!spark.sqlContext.tableNames().contains("t"))
assert(!spark.sqlContext.tableNames(db).contains("t"))
}
}

View file

@ -1417,7 +1417,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin)
checkAnswer(
spark.wrapped.tables().select('isTemporary).filter('tableName === "t2"),
spark.sqlContext.tables().select('isTemporary).filter('tableName === "t2"),
Row(true)
)

View file

@ -61,7 +61,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withOrcDataFrame(data) { df =>
spark.wrapped.registerDataFrameAsTable(df, tableName)
spark.sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}