[SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should un-cache all cached plans that refer to this table

## What changes were proposed in this pull request?

When un-cache a table, we should not only remove the cache entry for this table, but also un-cache any other cached plans that refer to this table.

This PR also includes some refactors:

1. use `java.util.LinkedList` to store the cache entries, so that it's safer to remove elements while iterating
2. rename `invalidateCache` to `recacheByPlan`, which is more obvious about what it does.

## How was this patch tested?

new regression test

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17097 from cloud-fan/cache.
This commit is contained in:
Wenchen Fan 2017-03-07 09:21:58 -08:00 committed by Xiao Li
parent 030acdd1f0
commit c05baabf10
9 changed files with 119 additions and 98 deletions

View file

@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.JavaConverters._
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.Logging
@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
class CacheManager extends Logging {
@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
private val cachedData = new java.util.LinkedList[CachedData]
@transient
private val cacheLock = new ReentrantReadWriteLock
@ -70,7 +72,7 @@ class CacheManager extends Logging {
/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}
@ -88,46 +90,81 @@ class CacheManager extends Logging {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val planToCache = query.logicalPlan
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
val sparkSession = query.sparkSession
cachedData +=
CachedData(
planToCache,
InMemoryRelation(
sparkSession.sessionState.conf.useCompression,
sparkSession.sessionState.conf.columnBatchSize,
storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName))
cachedData.add(CachedData(
planToCache,
InMemoryRelation(
sparkSession.sessionState.conf.useCompression,
sparkSession.sessionState.conf.columnBatchSize,
storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName)))
}
}
/**
* Tries to remove the data for the given [[Dataset]] from the cache.
* No operation, if it's already uncached.
* Un-cache all the cache entries that refer to the given plan.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
val found = dataIndex >= 0
if (found) {
cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
cachedData.remove(dataIndex)
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
}
/**
* Un-cache all the cache entries that refer to the given plan.
*/
def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
val it = cachedData.iterator()
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
it.remove()
}
}
found
}
/**
* Tries to re-cache all the cache entries that refer to the given plan.
*/
def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
}
private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
val it = cachedData.iterator()
val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
while (it.hasNext) {
val cd = it.next()
if (condition(cd.plan)) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist()
// Remove the cache entry before we create a new one, so that we can have a different
// physical plan.
it.remove()
val newCache = InMemoryRelation(
useCompression = cd.cachedRepresentation.useCompression,
batchSize = cd.cachedRepresentation.batchSize,
storageLevel = cd.cachedRepresentation.storageLevel,
child = spark.sessionState.executePlan(cd.plan).executedPlan,
tableName = cd.cachedRepresentation.tableName)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}
needToRecache.foreach(cachedData.add)
}
/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
lookupCachedData(query.logicalPlan)
}
/** Optionally returns cached data for the given [[LogicalPlan]]. */
def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
cachedData.find(cd => plan.sameResult(cd.plan))
cachedData.asScala.find(cd => plan.sameResult(cd.plan))
}
/** Replaces segments of the given logical plan with cached versions where possible. */
@ -145,40 +182,17 @@ class CacheManager extends Logging {
}
/**
* Invalidates the cache of any data that contains `plan`. Note that it is possible that this
* function will over invalidate.
*/
def invalidateCache(plan: LogicalPlan): Unit = writeLock {
cachedData.foreach {
case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
data.cachedRepresentation.recache()
case _ =>
}
}
/**
* Invalidates the cache of any data that contains `resourcePath` in one or more
* Tries to re-cache all the cache entries that contain `resourcePath` in one or more
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
def invalidateCachedPath(
sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
val (fs, qualifiedPath) = {
val path = new Path(resourcePath)
val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
(fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
(fs, fs.makeQualified(path))
}
cachedData.filter {
case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true
case _ => false
}.foreach { data =>
val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
if (dataIndex >= 0) {
data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
cachedData.remove(dataIndex)
}
sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
}
recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
}
/**

View file

@ -85,12 +85,6 @@ case class InMemoryRelation(
buildBuffers()
}
def recache(): Unit = {
_cachedColumnBuffers.unpersist()
_cachedColumnBuffers = null
buildBuffers()
}
private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>

View file

@ -199,8 +199,7 @@ case class DropTableCommand(
}
}
try {
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession.table(tableName.quotedString))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
} catch {
case _: NoSuchTableException if ifExists =>
case NonFatal(e) => log.warn(e.toString, e)

View file

@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)
// Invalidate the cache.
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
// Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
// data source relation.
sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
Seq.empty[Row]
}

View file

@ -343,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def dropTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
@ -359,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
@ -404,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
}
/**
@ -442,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
// If this table is cached as an InMemoryRelation, drop the original
// cached version and make the new version cached lazily.
val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed
// Use lookupCachedData directly since RefreshTable also takes databaseName.
val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
if (isCached) {
// Create a data frame to represent the table.
// TODO: Use uncacheTable once it supports database name.
val df = Dataset.ofRows(sparkSession, logicalPlan)
val table = sparkSession.table(tableIdent)
if (isCached(table)) {
// Uncache the logicalPlan.
sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
// Cache it again.
sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
}
@ -464,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def refreshByPath(resourcePath: String): Unit = {
sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
}
}

View file

@ -24,15 +24,15 @@ import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.apache.spark.CleanerListener
import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.AccumulatorContext
import org.apache.spark.util.{AccumulatorContext, Utils}
private case class BigData(s: String)
@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
maybeBlock.nonEmpty
}
private def getNumInMemoryRelations(plan: LogicalPlan): Int = {
private def getNumInMemoryRelations(ds: Dataset[_]): Int = {
val plan = ds.queryExecution.withCachedData
var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
plan.transformAllExpressions {
case e: SubqueryExpression =>
@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
assertCached(spark.table("testData"))
assertResult(1, "InMemoryRelation not found, testData should have been cached") {
getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData)
getNumInMemoryRelations(spark.table("testData"))
}
spark.catalog.cacheTable("testData")
@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
localRelation.createOrReplaceTempView("localRelation")
spark.catalog.cacheTable("localRelation")
assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1)
assert(getNumInMemoryRelations(localRelation) == 1)
}
test("SPARK-19093 Caching in side subquery") {
withTempView("t1") {
Seq(1).toDF("c1").createOrReplaceTempView("t1")
spark.catalog.cacheTable("t1")
val cachedPlan =
val ds =
sql(
"""
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1)
""".stripMargin).queryExecution.optimizedPlan
assert(getNumInMemoryRelations(cachedPlan) == 2)
""".stripMargin)
assert(getNumInMemoryRelations(ds) == 2)
}
}
@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
spark.catalog.cacheTable("t4")
// Nested predicate subquery
val cachedPlan =
val ds =
sql(
"""
|SELECT * FROM t1
|WHERE
|c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
""".stripMargin).queryExecution.optimizedPlan
assert(getNumInMemoryRelations(cachedPlan) == 3)
""".stripMargin)
assert(getNumInMemoryRelations(ds) == 3)
// Scalar subquery and predicate subquery
val cachedPlan2 =
val ds2 =
sql(
"""
|SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
@ -630,8 +631,27 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
|EXISTS (SELECT c1 FROM t3)
|OR
|c1 IN (SELECT c1 FROM t4)
""".stripMargin).queryExecution.optimizedPlan
assert(getNumInMemoryRelations(cachedPlan2) == 4)
""".stripMargin)
assert(getNumInMemoryRelations(ds2) == 4)
}
}
test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") {
withTable("t") {
withTempPath { path =>
Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
sql(s"CREATE TABLE t USING parquet LOCATION '$path'")
spark.catalog.cacheTable("t")
spark.table("t").select($"i").cache()
checkAnswer(spark.table("t").select($"i"), Row(1))
assertCached(spark.table("t").select($"i"))
Utils.deleteRecursively(path)
spark.sessionState.catalog.refreshTable(TableIdentifier("t"))
spark.catalog.uncacheTable("t")
assert(spark.table("t").select($"i").count() == 0)
assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0)
}
}
}

View file

@ -393,8 +393,8 @@ case class InsertIntoHiveTable(
logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
}
// Invalidate the cache.
sparkSession.catalog.uncacheTable(table.qualifiedName)
// un-cache this table.
sparkSession.catalog.uncacheTable(table.identifier.quotedString)
sparkSession.sessionState.catalog.refreshTable(table.identifier)
// It would be nice to just return the childRdd unchanged so insert operations could be chained,

View file

@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
tempPath.delete()
table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString)
sql("DROP TABLE IF EXISTS refreshTable")
sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet")
checkAnswer(
table("refreshTable"),
table("src").collect())
sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet")
checkAnswer(table("refreshTable"), table("src"))
// Cache the table.
sql("CACHE TABLE refreshTable")
assertCached(table("refreshTable"))

View file

@ -453,7 +453,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
// Converted test_parquet should be cached.
sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match {
case null => fail("Converted test_parquet should be cached in the cache.")
case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK
case LogicalRelation(_: HadoopFsRelation, _, _) => // OK
case other =>
fail(
"The cached test_parquet should be a Parquet Relation. " +