[SPARK-30556][SQL][FOLLOWUP] Reset the status changed in SQLExecution.withThreadLocalCaptured

### What changes were proposed in this pull request?
Follow up for #27267, reset the status changed in SQLExecution.withThreadLocalCaptured.

### Why are the changes needed?
For code safety.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing UT.

Closes #27516 from xuanyuanking/SPARK-30556-follow.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: herman <herman@databricks.com>
This commit is contained in:
Yuanjian Li 2020-02-10 22:16:25 +01:00 committed by herman
parent 3c1c9b48fc
commit a6b91d2bf7
2 changed files with 17 additions and 5 deletions

View file

@ -177,9 +177,19 @@ object SQLExecution {
val sc = sparkSession.sparkContext
val localProps = Utils.cloneProperties(sc.getLocalProperties)
Future {
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
SparkSession.setActiveSession(activeSession)
sc.setLocalProperties(localProps)
body
val res = body
// reset active session and local props.
sc.setLocalProperties(originalLocalProps)
if (originalSession.nonEmpty) {
SparkSession.setActiveSession(originalSession.get)
} else {
SparkSession.clearActiveSession()
}
res
}(exec)
}
}

View file

@ -17,6 +17,8 @@
package org.apache.spark.sql.internal
import java.util.UUID
import org.scalatest.Assertions._
import org.apache.spark.{SparkException, SparkFunSuite, TaskContext}
@ -144,16 +146,16 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
}
// set local configuration and assert
val confValue1 = "e"
val confValue1 = UUID.randomUUID().toString()
createDataframe(confKey, confValue1).createOrReplaceTempView("m")
spark.sparkContext.setLocalProperty(confKey, confValue1)
assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect.size == 1)
assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect().length == 1)
// change the conf value and assert again
val confValue2 = "f"
val confValue2 = UUID.randomUUID().toString()
createDataframe(confKey, confValue2).createOrReplaceTempView("n")
spark.sparkContext.setLocalProperty(confKey, confValue2)
assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().size == 1)
assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().length == 1)
}
}
}