[SPARK-35798][SQL] Fix SparkPlan.sqlContext usage

### What changes were proposed in this pull request?
There might be `SparkPlan` nodes where canonicalization on the executor side can cause issues. This PR removes the `sqlContext` field of `SparkPlan` in favour of a `session` field and a `conf` accessor that falls back to `super.conf` when no active session is available, and updates all `sqlContext` usages in the physical operators accordingly. This is a follow-up fix to the conversation at https://github.com/apache/spark/pull/32885/files#r651019687.

### Why are the changes needed?
To avoid potential NPEs when canonicalization happens on executors.
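
For illustration, here is a minimal, self-contained sketch of the failure mode and of the fallback pattern used here. All names below (`ActiveSession`, `PlanNode`, `SQLConfLike`, and so on) are simplified stand-ins, not Spark's real classes; they only mimic how the active session is set on the driver but not on executors:

```scala
// Hypothetical, simplified model of the pattern this PR changes -- none of these
// classes are Spark's; they only mimic the driver/executor difference.

final case class SQLConfLike(limitScaleUpFactor: Int)
final case class Session(conf: SQLConfLike)

object ActiveSession {
  // Stands in for SparkSession.getActiveSession: populated on the driver,
  // empty on executors.
  private val current = new ThreadLocal[Option[Session]] {
    override def initialValue(): Option[Session] = None
  }
  def get: Option[Session] = current.get()
  def set(s: Session): Unit = current.set(Some(s))
}

object Defaults {
  // Stands in for a session-independent default such as SQLConf.get.
  val conf: SQLConfLike = SQLConfLike(limitScaleUpFactor = 4)
}

abstract class PlanNode extends Serializable {
  // Old pattern: captured from the active session, so it is null on executors and
  // any dereference there (e.g. during canonicalization) throws a NullPointerException.
  @transient final val sqlContext: Session = ActiveSession.get.orNull

  // New pattern: keep a session handle, but read configuration through a method
  // that falls back to a session-independent default when no session is active.
  @transient final val session: Session = ActiveSession.get.orNull
  def conf: SQLConfLike = if (session != null) session.conf else Defaults.conf

  // Safe on both driver and executor, unlike sqlContext.conf.limitScaleUpFactor.
  def limitScaleUpFactor: Int = math.max(conf.limitScaleUpFactor, 2)
}

object ExecutorSideDemo {
  private class Scan extends PlanNode

  def main(args: Array[String]): Unit = {
    // No active session here, as on an executor: the old field is null,
    // but the new accessor still returns a usable value.
    val plan = new Scan
    assert(plan.sqlContext == null)
    assert(plan.limitScaleUpFactor == 4)
  }
}
```

In the real change (see the `SparkPlan` hunk below), `conf` falls back to `super.conf` when `session` is null, so executor-side code paths such as canonicalization no longer dereference a null `sqlContext`.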

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing UTs.

Closes #32947 from peter-toth/SPARK-35798-fix-sparkplan.sqlcontext-usage.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Peter Toth 2021-06-17 13:49:38 +00:00 committed by Wenchen Fan
parent b86a69f026
commit abf9675a75
35 changed files with 97 additions and 99 deletions

@@ -62,7 +62,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
override def doExecute(): RDD[InternalRow] = {
val broadcastedHadoopConf =
- new SerializableConfiguration(sqlContext.sessionState.newHadoopConf())
+ new SerializableConfiguration(session.sessionState.newHadoopConf())
child.execute().mapPartitions { iter =>
if (iter.hasNext) {

@@ -441,7 +441,7 @@ case class RowToColumnarExec(child: SparkPlan) extends RowToColumnarTransition {
)
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
- val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled
+ val enableOffHeapColumnVector = conf.offHeapColumnVectorEnabled
val numInputRows = longMetric("numInputRows")
val numOutputBatches = longMetric("numOutputBatches")
// Instead of creating a new config we are reusing columnBatchSize. In the future if we do

@@ -51,11 +51,11 @@ case class CommandResultExec(
@transient private lazy val rdd: RDD[InternalRow] = {
if (rows.isEmpty) {
- sqlContext.sparkContext.emptyRDD
+ sparkContext.emptyRDD
} else {
val numSlices = math.min(
- unsafeRows.length, sqlContext.sparkSession.leafNodeDefaultParallelism)
- sqlContext.sparkContext.parallelize(unsafeRows, numSlices)
+ unsafeRows.length, session.leafNodeDefaultParallelism)
+ sparkContext.parallelize(unsafeRows, numSlices)
}
}

@@ -54,7 +54,7 @@ trait DataSourceScanExec extends LeafExecNode {
// Metadata that describes more details of this scan.
protected def metadata: Map[String, String]
- protected val maxMetadataValueLength = sqlContext.sessionState.conf.maxMetadataStringLength
+ protected val maxMetadataValueLength = conf.maxMetadataStringLength
override def simpleString(maxFields: Int): String = {
val metadataEntries = metadata.toSeq.sorted.map {
@@ -86,7 +86,7 @@ trait DataSourceScanExec extends LeafExecNode {
* Shorthand for calling redactString() without specifying redacting rules
*/
protected def redact(text: String): String = {
- Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text)
+ Utils.redact(conf.stringRedactionPattern, text)
}
/**
@@ -179,7 +179,7 @@ case class FileSourceScanExec(
private lazy val needsUnsafeRowConversion: Boolean = {
if (relation.fileFormat.isInstanceOf[ParquetSource]) {
- sqlContext.conf.parquetVectorizedReaderEnabled
+ conf.parquetVectorizedReaderEnabled
} else {
false
}

@@ -47,11 +47,11 @@ case class LocalTableScanExec(
@transient private lazy val rdd: RDD[InternalRow] = {
if (rows.isEmpty) {
- sqlContext.sparkContext.emptyRDD
+ sparkContext.emptyRDD
} else {
val numSlices = math.min(
- unsafeRows.length, sqlContext.sparkSession.leafNodeDefaultParallelism)
- sqlContext.sparkContext.parallelize(unsafeRows, numSlices)
+ unsafeRows.length, session.leafNodeDefaultParallelism)
+ sparkContext.parallelize(unsafeRows, numSlices)
}
}

@@ -55,7 +55,7 @@ case class SortExec(
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- private val enableRadixSort = sqlContext.conf.enableRadixSort
+ private val enableRadixSort = conf.enableRadixSort
override lazy val metrics = Map(
"sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"),

@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
import org.apache.spark.sql.execution.metric.SQLMetric
+ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
object SparkPlan {
@@ -55,15 +56,17 @@ object SparkPlan {
* The naming convention is that physical operators end with "Exec" suffix, e.g. [[ProjectExec]].
*/
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
+ @transient final val session = SparkSession.getActiveSession.orNull
- /**
- * A handle to the SQL Context that was used to create this plan. Since many operators need
- * access to the sqlContext for RDD operations or configuration this field is automatically
- * populated by the query planning infrastructure.
- */
- @transient final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull
+ protected def sparkContext = session.sparkContext
- protected def sparkContext = sqlContext.sparkContext
+ override def conf: SQLConf = {
+ if (session != null) {
+ session.sessionState.conf
+ } else {
+ super.conf
+ }
+ }
val id: Int = SparkPlan.newPlanId()
@@ -80,8 +83,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
- if (sqlContext != null) {
- SparkSession.setActiveSession(sqlContext.sparkSession)
+ if (session != null) {
+ SparkSession.setActiveSession(session)
}
super.makeCopy(newArgs)
}
@@ -448,7 +451,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
// If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate
// it by 50%. We also cap the estimation in the end.
- val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
+ val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
if (buf.isEmpty) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
@@ -467,7 +470,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
} else {
parts
}
- val sc = sqlContext.sparkContext
+ val sc = sparkContext
val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
if (it.hasNext) it.next() else (0L, Array.emptyByteArray), partsToScan)

@@ -74,7 +74,7 @@ case class SubqueryBroadcastExec(
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
- SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+ SQLExecution.withExecutionId(session, executionId) {
val beforeCollect = System.nanoTime()
val broadcastRelation = child.executeBroadcast[HashedRelation]().value

@@ -724,17 +724,17 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
val (_, compiledCodeStats) = try {
CodeGenerator.compile(cleanedSource)
} catch {
- case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback =>
+ case NonFatal(_) if !Utils.isTesting && conf.codegenFallback =>
// We should already saw the error message
logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
return child.execute()
}
// Check if compiled code has a too large function
- if (compiledCodeStats.maxMethodCodeSize > sqlContext.conf.hugeMethodLimit) {
+ if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {
logInfo(s"Found too long generated codes and JIT optimization might not work: " +
s"the bytecode size (${compiledCodeStats.maxMethodCodeSize}) is above the limit " +
- s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
+ s"${conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
return child.execute()

@@ -62,7 +62,7 @@ object AggUtils {
resultExpressions = resultExpressions,
child = child)
} else {
- val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
+ val objectHashEnabled = child.conf.useObjectHashAggregation
val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
if (objectHashEnabled && useObjectHash) {

@@ -73,8 +73,8 @@ case class HashAggregateExec(
// This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
// map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
- Option(sqlContext).map { sc =>
- sc.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null)
+ Option(session).map { s =>
+ s.conf.get("spark.sql.TungstenAggregate.testFallbackStartsAt", null)
}.orNull match {
case null | "" => None
case fallbackStartsAt =>
@@ -679,15 +679,15 @@ case class HashAggregateExec(
// This is for testing/benchmarking only.
// We enforce to first level to be a vectorized hashmap, instead of the default row-based one.
- isVectorizedHashMapEnabled = sqlContext.conf.enableVectorizedHashMap
+ isVectorizedHashMapEnabled = conf.enableVectorizedHashMap
}
}
private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
- if (sqlContext.conf.enableTwoLevelAggMap) {
+ if (conf.enableTwoLevelAggMap) {
enableTwoLevelHashMap(ctx)
- } else if (sqlContext.conf.enableVectorizedHashMap) {
+ } else if (conf.enableVectorizedHashMap) {
logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
}
val bitMaxCapacity = testFallbackStartsAt match {
@@ -700,7 +700,7 @@ case class HashAggregateExec(
} else {
(math.log10(fastMapCounter) / math.log10(2)).floor.toInt
}
- case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+ case _ => conf.fastHashAggregateRowMaxCapacityBit
}
val thisPlan = ctx.addReferenceObj("plan", this)

@@ -83,7 +83,7 @@ case class ObjectHashAggregateExec(
val aggTime = longMetric("aggTime")
val spillSize = longMetric("spillSize")
val numTasksFallBacked = longMetric("numTasksFallBacked")
- val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
+ val fallbackCountThreshold = conf.objectAggSortBasedFallbackThreshold
child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
val beforeAgg = System.nanoTime()

@@ -46,8 +46,8 @@ case class UpdatingSessionsExec(
groupingWithoutSessionExpression.map(_.toAttribute)
override protected def doExecute(): RDD[InternalRow] = {
- val inMemoryThreshold = sqlContext.conf.sessionWindowBufferInMemoryThreshold
- val spillThreshold = sqlContext.conf.sessionWindowBufferSpillThreshold
+ val inMemoryThreshold = conf.sessionWindowBufferInMemoryThreshold
+ val spillThreshold = conf.sessionWindowBufferSpillThreshold
child.execute().mapPartitions { iter =>
new UpdatingSessionsIterator(iter, groupingExpression, sessionExpression,

@@ -413,7 +413,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val start: Long = range.start
val end: Long = range.end
val step: Long = range.step
- val numSlices: Int = range.numSlices.getOrElse(sqlContext.sparkSession.leafNodeDefaultParallelism)
+ val numSlices: Int = range.numSlices.getOrElse(session.leafNodeDefaultParallelism)
val numElements: BigInt = range.numElements
val isEmptyRange: Boolean = start == end || (start < end ^ 0 < step)
@@ -442,9 +442,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
override def inputRDDs(): Seq[RDD[InternalRow]] = {
val rdd = if (isEmptyRange) {
- new EmptyRDD[InternalRow](sqlContext.sparkContext)
+ new EmptyRDD[InternalRow](sparkContext)
} else {
- sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+ sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
}
rdd :: Nil
}
@@ -608,10 +608,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
if (isEmptyRange) {
- new EmptyRDD[InternalRow](sqlContext.sparkContext)
+ new EmptyRDD[InternalRow](sparkContext)
} else {
- sqlContext
- .sparkContext
+ sparkContext
.parallelize(0 until numSlices, numSlices)
.mapPartitionsWithIndex { (i, _) =>
val partitionStart = (i * numElements) / numSlices * step + start
@@ -814,11 +813,11 @@ case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: Option[Int]
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
- sqlContext.sparkSession,
+ session,
SubqueryExec.executionContext) {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
- SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+ SQLExecution.withExecutionId(session, executionId) {
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
val rows: Array[InternalRow] = if (maxNumRows.isDefined) {

@@ -208,8 +208,8 @@ case class CachedRDDBuilder(
@transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null
- val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator
- val rowCountStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator
+ val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
+ val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
val cachedName = tableName.map(n => s"In-memory table $n")
.getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))

@@ -132,13 +132,13 @@ case class InMemoryTableScanExec(
override def outputOrdering: Seq[SortOrder] =
relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
- lazy val enableAccumulatorsForTest: Boolean = sqlContext.conf.inMemoryTableScanStatisticsEnabled
+ lazy val enableAccumulatorsForTest: Boolean = conf.inMemoryTableScanStatisticsEnabled
// Accumulators used for testing purposes
lazy val readPartitions = sparkContext.longAccumulator
lazy val readBatches = sparkContext.longAccumulator
- private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
+ private val inMemoryPartitionPruningEnabled = conf.inMemoryPartitionPruning
private def filteredCachedBatches(): RDD[CachedBatch] = {
val buffers = relation.cacheBuilder.cachedColumnBuffers

@@ -72,7 +72,7 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {
*/
protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
- cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow])
+ cmd.run(session).map(converter(_).asInstanceOf[InternalRow])
}
override def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil
@@ -92,7 +92,7 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {
}
protected override def doExecute(): RDD[InternalRow] = {
- sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+ sparkContext.parallelize(sideEffectResult, 1)
}
}
@@ -110,7 +110,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
- val rows = cmd.run(sqlContext.sparkSession, child)
+ val rows = cmd.run(session, child)
rows.map(converter(_).asInstanceOf[InternalRow])
}
@@ -133,7 +133,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
}
protected override def doExecute(): RDD[InternalRow] = {
- sqlContext.sparkContext.parallelize(sideEffectResult, 1)
+ sparkContext.parallelize(sideEffectResult, 1)
}
override protected def withNewChildInternal(newChild: SparkPlan): DataWritingCommandExec =

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
import java.util.Locale
- import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+ import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.LocalTempView
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -36,8 +36,6 @@ trait BaseCacheTableExec extends LeafV2CommandExec {
def isLazy: Boolean
def options: Map[String, String]
- protected val sparkSession: SparkSession = sqlContext.sparkSession
override def run(): Seq[InternalRow] = {
val storageLevelKey = "storagelevel"
val storageLevelValue =
@@ -48,14 +46,14 @@
}
if (storageLevelValue.nonEmpty) {
- sparkSession.sharedState.cacheManager.cacheQuery(
- sparkSession,
+ session.sharedState.cacheManager.cacheQuery(
+ session,
planToCache,
Some(relationName),
StorageLevel.fromString(storageLevelValue.get))
} else {
- sparkSession.sharedState.cacheManager.cacheQuery(
- sparkSession,
+ session.sharedState.cacheManager.cacheQuery(
+ session,
planToCache,
Some(relationName))
}
@@ -81,7 +79,7 @@ case class CacheTableExec(
override lazy val planToCache: LogicalPlan = relation
override lazy val dataFrameForCachedPlan: DataFrame = {
- Dataset.ofRows(sparkSession, planToCache)
+ Dataset.ofRows(session, planToCache)
}
}
@@ -105,13 +103,13 @@ case class CacheTableAsSelectExec(
replace = false,
viewType = LocalTempView,
isAnalyzed = true
- ).run(sparkSession)
+ ).run(session)
dataFrameForCachedPlan.logicalPlan
}
override lazy val dataFrameForCachedPlan: DataFrame = {
- sparkSession.table(tempViewName)
+ session.table(tempViewName)
}
}
@@ -119,8 +117,7 @@ case class UncacheTableExec(
relation: LogicalPlan,
cascade: Boolean) extends LeafV2CommandExec {
override def run(): Seq[InternalRow] = {
- val sparkSession = sqlContext.sparkSession
- sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade)
+ session.sharedState.cacheManager.uncacheQuery(session, relation, cascade)
Seq.empty
}

@@ -54,8 +54,8 @@ case class ContinuousScanExec(
.askSync[Unit](SetReaderPartitions(partitions.size))
new ContinuousDataSourceRDD(
sparkContext,
- sqlContext.conf.continuousStreamingExecutorQueueSize,
- sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
+ conf.continuousStreamingExecutorQueueSize,
+ conf.continuousStreamingExecutorPollIntervalMs,
partitions,
schema,
readerFactory.asInstanceOf[ContinuousPartitionReaderFactory],

@@ -57,7 +57,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
* Shorthand for calling redact() without specifying redacting rules
*/
protected def redact(text: String): String = {
- Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text)
+ Utils.redact(session.sessionState.conf.stringRedactionPattern, text)
}
override def verboseStringWithOperatorId(): String = {

@@ -54,7 +54,7 @@ case class RenameTableExec(
val tbl = catalog.loadTable(qualifiedNewIdent)
val newRelation = DataSourceV2Relation.create(tbl, Some(catalog), Some(qualifiedNewIdent))
cacheTable(
- sqlContext.sparkSession,
+ session,
newRelation,
Some(qualifiedNewIdent.quoted), oldStorageLevel)
}

@@ -76,7 +76,7 @@ trait SupportsV1Write extends SparkPlan {
def plan: LogicalPlan
protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = {
- relation.insert(Dataset.ofRows(sqlContext.sparkSession, plan), overwrite = false)
+ relation.insert(Dataset.ofRows(session, plan), overwrite = false)
Nil
}
}

@@ -55,7 +55,7 @@ abstract class V2CommandExec extends SparkPlan {
override def executeTail(limit: Int): Array[InternalRow] = result.takeRight(limit).toArray
protected override def doExecute(): RDD[InternalRow] = {
- sqlContext.sparkContext.parallelize(result, 1)
+ sparkContext.parallelize(result, 1)
}
override def producedAttributes: AttributeSet = outputSet

@@ -122,7 +122,7 @@ case class BroadcastExchangeExec(
@transient
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
- sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
+ session, BroadcastExchangeExec.executionContext) {
try {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",

@@ -71,7 +71,7 @@ case class BroadcastHashJoinExec(
override lazy val outputPartitioning: Partitioning = {
joinType match {
- case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
+ case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
@@ -112,7 +112,7 @@ case class BroadcastHashJoinExec(
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
- val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit
+ val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
var currentNumCombinations = 0
def generateExprCombinations(

@@ -80,8 +80,8 @@ case class CartesianProductExec(
val pair = new UnsafeCartesianRDD(
leftResults,
rightResults,
- sqlContext.conf.cartesianProductExecBufferInMemoryThreshold,
- sqlContext.conf.cartesianProductExecBufferSpillThreshold)
+ conf.cartesianProductExecBufferInMemoryThreshold,
+ conf.cartesianProductExecBufferSpillThreshold)
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {

@@ -102,7 +102,7 @@ case class SortMergeJoinExec(
UnsafeProjection.create(rightKeys, right.output)
private def getSpillThreshold: Int = {
- sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
+ conf.sortMergeJoinExecBufferSpillThreshold
}
// Flag to only buffer first matched row, to avoid buffering unnecessary rows.
@@ -115,7 +115,7 @@ case class SortMergeJoinExec(
if (onlyBufferFirstMatchedRow) {
1
} else {
- sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
+ conf.sortMergeJoinExecBufferInMemoryThreshold
}
}

@@ -109,8 +109,8 @@ case class FlatMapGroupsWithStateExec(
groupingAttributes.toStructType,
stateManager.stateSchema,
indexOrdinal = None,
- sqlContext.sessionState,
- Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+ session.sessionState,
+ Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val commitTimeMs = longMetric("commitTimeMs")

@@ -176,10 +176,10 @@ case class StreamingSymmetricHashJoinExec(
errorMessageForJoinType)
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))
- private val storeConf = new StateStoreConf(sqlContext.conf)
+ private val storeConf = new StateStoreConf(conf)
private val hadoopConfBcast = sparkContext.broadcast(
new SerializableConfiguration(SessionState.newHadoopConf(
- sparkContext.hadoopConfiguration, sqlContext.conf)))
+ sparkContext.hadoopConfiguration, conf)))
val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
@@ -219,7 +219,7 @@ case class StreamingSymmetricHashJoinExec(
}
protected override def doExecute(): RDD[InternalRow] = {
- val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator
+ val stateStoreCoord = session.sessionState.streamingQueryManager.stateStoreCoordinator
val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
left.execute().stateStoreAwareZipPartitions(
right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions)

@@ -124,7 +124,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
}
private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
- val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass)
+ val provider = StateStoreProvider.create(conf.stateStoreProviderClass)
provider.supportedCustomMetrics.map {
metric => (metric.name, metric.createSQLMetric(sparkContext))
}.toMap
@@ -246,8 +246,8 @@ case class StateStoreRestoreExec(
keyExpressions.toStructType,
stateManager.getStateValueSchema,
indexOrdinal = None,
- sqlContext.sessionState,
- Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+ session.sessionState,
+ Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val hasInput = iter.hasNext
if (!hasInput && keyExpressions.isEmpty) {
// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
@@ -308,8 +308,8 @@ case class StateStoreSaveExec(
keyExpressions.toStructType,
stateManager.getStateValueSchema,
indexOrdinal = None,
- sqlContext.sessionState,
- Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+ session.sessionState,
+ Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
val numOutputRows = longMetric("numOutputRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
@@ -461,8 +461,8 @@ case class StreamingDeduplicateExec(
keyExpressions.toStructType,
child.output.toStructType,
indexOrdinal = None,
- sqlContext.sessionState,
- Some(sqlContext.streams.stateStoreCoordinator),
+ session.sessionState,
+ Some(session.streams.stateStoreCoordinator),
// We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW
// is unrelated to the output schema.
Map(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false")) { (store, iter) =>

@@ -53,8 +53,8 @@ case class StreamingGlobalLimitExec(
keySchema,
valueSchema,
indexOrdinal = None,
- sqlContext.sessionState,
- Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+ session.sessionState,
+ Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
val numOutputRows = longMetric("numOutputRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")

@@ -133,7 +133,7 @@ case class InSubqueryExec(
} else {
rows.map(_.get(0, child.dataType))
}
- resultBroadcast = plan.sqlContext.sparkContext.broadcast(result)
+ resultBroadcast = plan.session.sparkContext.broadcast(result)
}
def values(): Option[Array[Any]] = Option(resultBroadcast).map(_.value)

@@ -114,8 +114,8 @@ case class WindowExec(
// Unwrap the window expressions and window frame factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
- val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold
- val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
+ val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
+ val spillThreshold = conf.windowExecBufferSpillThreshold
// Start processing.
child.execute().mapPartitions { stream =>

@@ -192,8 +192,7 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {
override protected def doExecute(): RDD[InternalRow] = {
- sqlContext
- .sparkContext
+ sparkContext
.parallelize(0 until 2, 2)
.mapPartitions { it =>
val confs = SQLConf.get

@@ -202,11 +202,11 @@ case class HiveTableScanExec(
// Using dummyCallSite, as getCallSite can turn out to be expensive with
// multiple partitions.
val rdd = if (!relation.isPartitioned) {
- Utils.withDummyCallSite(sqlContext.sparkContext) {
+ Utils.withDummyCallSite(sparkContext) {
hadoopReader.makeRDDForTable(hiveQlTable)
}
} else {
- Utils.withDummyCallSite(sqlContext.sparkContext) {
+ Utils.withDummyCallSite(sparkContext) {
hadoopReader.makeRDDForPartitionedTable(prunedPartitions)
}
}