[SPARK-6844][SQL] Clean up accumulators used in InMemoryRelation when it is uncached
JIRA: https://issues.apache.org/jira/browse/SPARK-6844

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #5475 from viirya/cache_memory_leak and squashes the following commits:

0b41235 [Liang-Chi Hsieh] fix style.
dc1d5d5 [Liang-Chi Hsieh] For comments.
78af229 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cache_memory_leak
26c9bb6 [Liang-Chi Hsieh] Add configuration to enable in-memory table scan accumulators.
1c3b06e [Liang-Chi Hsieh] Clean up accumulators used in InMemoryRelation when it is uncached.
This commit is contained in:
parent
85842760dc
commit
cf38fe04f8
|
@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
|
|||
val planToCache = query.queryExecution.analyzed
|
||||
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
|
||||
require(dataIndex >= 0, s"Table $query is not cached.")
|
||||
cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
|
||||
cachedData(dataIndex).cachedRepresentation.uncache(blocking)
|
||||
cachedData.remove(dataIndex)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar
|
|||
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
import org.apache.spark.Accumulator
|
||||
import org.apache.spark.{Accumulable, Accumulator, Accumulators}
|
||||
import org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation(
|
|||
child: SparkPlan,
|
||||
tableName: Option[String])(
|
||||
private var _cachedColumnBuffers: RDD[CachedBatch] = null,
|
||||
private var _statistics: Statistics = null)
|
||||
private var _statistics: Statistics = null,
|
||||
private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null)
|
||||
extends LogicalPlan with MultiInstanceRelation {
|
||||
|
||||
private val batchStats =
|
||||
child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
|
||||
private val batchStats: Accumulable[ArrayBuffer[Row], Row] =
|
||||
if (_batchStats == null) {
|
||||
child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
|
||||
} else {
|
||||
_batchStats
|
||||
}
|
||||
|
||||
val partitionStatistics = new PartitionStatistics(output)
|
||||
|
||||
|
@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation(
|
|||
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
|
||||
InMemoryRelation(
|
||||
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
|
||||
_cachedColumnBuffers, statisticsToBePropagated)
|
||||
_cachedColumnBuffers, statisticsToBePropagated, batchStats)
|
||||
}
|
||||
|
||||
override def children: Seq[LogicalPlan] = Seq.empty
|
||||
|
@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation(
|
|||
child,
|
||||
tableName)(
|
||||
_cachedColumnBuffers,
|
||||
statisticsToBePropagated).asInstanceOf[this.type]
|
||||
statisticsToBePropagated,
|
||||
batchStats).asInstanceOf[this.type]
|
||||
}
|
||||
|
||||
def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
|
||||
|
||||
override protected def otherCopyArgs: Seq[AnyRef] =
|
||||
Seq(_cachedColumnBuffers, statisticsToBePropagated)
|
||||
Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
|
||||
|
||||
/**
 * Releases what this relation holds for its cache: removes the `batchStats`
 * accumulator from `Accumulators` by id, unpersists the cached column buffers,
 * and clears the reference to them.
 *
 * @param blocking passed through unchanged to `unpersist` on the cached column buffers
 */
private[sql] def uncache(blocking: Boolean): Unit = {
|
||||
// Remove the `batchStats` accumulator, looked up by its id.
Accumulators.remove(batchStats.id)
|
||||
cachedColumnBuffers.unpersist(blocking)
|
||||
// Clear the backing field; `cachedColumnBuffers` returns null after this point.
_cachedColumnBuffers = null
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] case class InMemoryColumnarTableScan(
|
||||
|
@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan(
|
|||
}
|
||||
}
|
||||
|
||||
lazy val enableAccumulators: Boolean =
|
||||
sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
|
||||
|
||||
// Accumulators used for testing purposes
|
||||
val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
|
||||
val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
|
||||
lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
|
||||
lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
|
||||
|
||||
private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
|
||||
|
||||
override def execute(): RDD[Row] = {
|
||||
readPartitions.setValue(0)
|
||||
readBatches.setValue(0)
|
||||
if (enableAccumulators) {
|
||||
readPartitions.setValue(0)
|
||||
readBatches.setValue(0)
|
||||
}
|
||||
|
||||
relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator =>
|
||||
val partitionFilter = newPredicate(
|
||||
|
@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan(
|
|||
}
|
||||
}
|
||||
|
||||
if (rows.hasNext) {
|
||||
if (rows.hasNext && enableAccumulators) {
|
||||
readPartitions += 1
|
||||
}
|
||||
|
||||
|
@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan(
|
|||
logInfo(s"Skipping partition based on stats $statsString")
|
||||
false
|
||||
} else {
|
||||
readBatches += 1
|
||||
if (enableAccumulators) {
|
||||
readBatches += 1
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps}
|
|||
|
||||
import org.scalatest.concurrent.Eventually._
|
||||
|
||||
import org.apache.spark.Accumulators
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.columnar._
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
|
@ -297,4 +298,21 @@ class CachedTableSuite extends QueryTest {
|
|||
sql("Clear CACHE")
|
||||
assert(cacheManager.isEmpty)
|
||||
}
|
||||
|
||||
// Checks that caching, querying and then uncaching two tables leaves
// `Accumulators.originals` no larger than it was before the test started.
test("Clear accumulators when uncacheTable to prevent memory leaking") {
|
||||
// Baseline: number of entries in `Accumulators.originals` before anything is cached.
val accsSize = Accumulators.originals.size
|
||||
|
||||
// Register two temp tables and cache both of them.
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
|
||||
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
|
||||
cacheTable("t1")
|
||||
cacheTable("t2")
|
||||
// Run a count over each cached table twice.
sql("SELECT * FROM t1").count()
|
||||
sql("SELECT * FROM t2").count()
|
||||
sql("SELECT * FROM t1").count()
|
||||
sql("SELECT * FROM t2").count()
|
||||
// Uncache both tables again.
uncacheTable("t1")
|
||||
uncacheTable("t2")
|
||||
|
||||
// The entry count after uncaching must not exceed the baseline.
assert(accsSize >= Accumulators.originals.size)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
|
|||
|
||||
// Enable in-memory partition pruning
|
||||
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
|
||||
// Enable in-memory table scan accumulators
|
||||
setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
|
||||
}
|
||||
|
||||
override protected def afterAll(): Unit = {
|
||||
|
|
Loading…
Reference in a new issue