diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 995d94ef5e..9f177819f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -177,9 +177,19 @@ object SQLExecution { val sc = sparkSession.sparkContext val localProps = Utils.cloneProperties(sc.getLocalProperties) Future { + val originalSession = SparkSession.getActiveSession + val originalLocalProps = sc.getLocalProperties SparkSession.setActiveSession(activeSession) sc.setLocalProperties(localProps) - body + val res = body + // reset active session and local props. + sc.setLocalProperties(originalLocalProps) + if (originalSession.nonEmpty) { + SparkSession.setActiveSession(originalSession.get) + } else { + SparkSession.clearActiveSession() + } + res }(exec) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 0cc658c499..46d0c64592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.UUID + import org.scalatest.Assertions._ import org.apache.spark.{SparkException, SparkFunSuite, TaskContext} @@ -144,16 +146,16 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } // set local configuration and assert - val confValue1 = "e" + val confValue1 = UUID.randomUUID().toString() createDataframe(confKey, confValue1).createOrReplaceTempView("m") spark.sparkContext.setLocalProperty(confKey, confValue1) - assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect.size == 1) + assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect().length == 1) // change the conf value and assert again - val confValue2 = "f" + val confValue2 = UUID.randomUUID().toString() createDataframe(confKey, confValue2).createOrReplaceTempView("n") spark.sparkContext.setLocalProperty(confKey, confValue2) - assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().size == 1) + assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().length == 1) } } }