[SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should un-cache all cached plans that refer to this table
## What changes were proposed in this pull request?

When un-caching a table, we should not only remove the cache entry for the table itself but also un-cache any other cached plans that refer to it.

This PR also includes two refactors:

1. Use `java.util.LinkedList` to store the cache entries, so that it is safer to remove elements while iterating (see the sketch right after this message).
2. Rename `invalidateCache` to `recacheByPlan`, which makes it more obvious what the method does.

## How was this patch tested?

A new regression test.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17097 from cloud-fan/cache.
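A note on refactor 1: removing entries by index from the old `ArrayBuffer` gets awkward once several entries can match, while `java.util.Iterator.remove()` is explicitly specified for deletion mid-iteration. Below is a minimal, self-contained sketch of the pattern the new `uncacheQuery` relies on; `Entry` and `refersTo` are hypothetical stand-ins for `CachedData` and the `cd.plan.find(_.sameResult(plan))` check, not code from this patch.

```scala
import java.util.LinkedList

// Hypothetical stand-in for CachedData; the real entries hold a logical
// plan and its InMemoryRelation.
final case class Entry(name: String)

object SafeRemovalSketch extends App {
  val cachedData = new LinkedList[Entry]()
  cachedData.add(Entry("t"))
  cachedData.add(Entry("t_projection"))
  cachedData.add(Entry("other"))

  // Stand-in for cd.plan.find(_.sameResult(plan)).isDefined.
  def refersTo(e: Entry, table: String): Boolean = e.name.startsWith(table)

  // java.util.Iterator.remove() deletes the element returned by the last
  // next() without invalidating the iterator, so every matching entry can
  // be dropped in a single pass.
  val it = cachedData.iterator()
  while (it.hasNext) {
    if (refersTo(it.next(), "t")) it.remove()
  }
  println(cachedData) // prints: [Entry(other)]
}
```

The same single-pass pattern appears again in `recacheByCondition` in the diff, which additionally collects rebuilt entries during the scan and re-adds them afterwards.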
parent 030acdd1f0
commit c05baabf10
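What the change means at the user level: `UNCACHE TABLE` (and `Catalog.uncacheTable`) now removes every cached plan that refers to the table, not just the table's own entry. The spark-shell style sketch below mirrors the regression test added in the diff; the table name, sample data, and `local[*]` master are illustrative only.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t")
spark.catalog.cacheTable("t")          // cache entry for the table itself
val proj = spark.table("t").select($"i")
proj.cache()                           // separate cache entry that refers to t

spark.catalog.uncacheTable("t")
assert(!spark.catalog.isCached("t"))
// New behavior: the cache entry for `proj` is dropped too, because its plan
// refers to t. Previously it survived and could keep serving stale data.
```

The comment on the last lines marks the part this patch changes: before the fix, only `t`'s own entry was dropped, so derived cached plans like `proj` could keep answering queries for a table that had been un-cached or even deleted.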
```diff
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.Logging
@@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
 class CacheManager extends Logging {
 
   @transient
-  private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+  private val cachedData = new java.util.LinkedList[CachedData]
 
   @transient
   private val cacheLock = new ReentrantReadWriteLock
@@ -70,7 +72,7 @@ class CacheManager extends Logging {
   /** Clears all cached tables. */
   def clearCache(): Unit = writeLock {
-    cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+    cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
     cachedData.clear()
   }
@@ -88,46 +90,81 @@ class CacheManager extends Logging {
       query: Dataset[_],
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
-    val planToCache = query.queryExecution.analyzed
+    val planToCache = query.logicalPlan
     if (lookupCachedData(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
       val sparkSession = query.sparkSession
-      cachedData +=
-        CachedData(
-          planToCache,
-          InMemoryRelation(
-            sparkSession.sessionState.conf.useCompression,
-            sparkSession.sessionState.conf.columnBatchSize,
-            storageLevel,
-            sparkSession.sessionState.executePlan(planToCache).executedPlan,
-            tableName))
+      cachedData.add(CachedData(
+        planToCache,
+        InMemoryRelation(
+          sparkSession.sessionState.conf.useCompression,
+          sparkSession.sessionState.conf.columnBatchSize,
+          storageLevel,
+          sparkSession.sessionState.executePlan(planToCache).executedPlan,
+          tableName)))
     }
   }
 
   /**
-   * Tries to remove the data for the given [[Dataset]] from the cache.
-   * No operation, if it's already uncached.
+   * Un-cache all the cache entries that refer to the given plan.
    */
-  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
-    val planToCache = query.queryExecution.analyzed
-    val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
-    val found = dataIndex >= 0
-    if (found) {
-      cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
-      cachedData.remove(dataIndex)
+  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
+    uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
+  }
+
+  /**
+   * Un-cache all the cache entries that refer to the given plan.
+   */
+  def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
+    val it = cachedData.iterator()
+    while (it.hasNext) {
+      val cd = it.next()
+      if (cd.plan.find(_.sameResult(plan)).isDefined) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+        it.remove()
+      }
     }
-    found
   }
 
+  /**
+   * Tries to re-cache all the cache entries that refer to the given plan.
+   */
+  def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
+    recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
+  }
+
+  private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+    val it = cachedData.iterator()
+    val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
+    while (it.hasNext) {
+      val cd = it.next()
+      if (condition(cd.plan)) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+        // Remove the cache entry before we create a new one, so that we can have a different
+        // physical plan.
+        it.remove()
+        val newCache = InMemoryRelation(
+          useCompression = cd.cachedRepresentation.useCompression,
+          batchSize = cd.cachedRepresentation.batchSize,
+          storageLevel = cd.cachedRepresentation.storageLevel,
+          child = spark.sessionState.executePlan(cd.plan).executedPlan,
+          tableName = cd.cachedRepresentation.tableName)
+        needToRecache += cd.copy(cachedRepresentation = newCache)
+      }
+    }
+
+    needToRecache.foreach(cachedData.add)
+  }
+
   /** Optionally returns cached data for the given [[Dataset]] */
   def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
-    lookupCachedData(query.queryExecution.analyzed)
+    lookupCachedData(query.logicalPlan)
   }
 
   /** Optionally returns cached data for the given [[LogicalPlan]]. */
   def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
-    cachedData.find(cd => plan.sameResult(cd.plan))
+    cachedData.asScala.find(cd => plan.sameResult(cd.plan))
   }
 
   /** Replaces segments of the given logical plan with cached versions where possible. */
@@ -145,40 +182,17 @@ class CacheManager extends Logging {
   }
 
   /**
-   * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
-   * function will over invalidate.
-   */
-  def invalidateCache(plan: LogicalPlan): Unit = writeLock {
-    cachedData.foreach {
-      case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
-        data.cachedRepresentation.recache()
-      case _ =>
-    }
-  }
-
-  /**
-   * Invalidates the cache of any data that contains `resourcePath` in one or more
+   * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
    * `HadoopFsRelation` node(s) as part of its logical plan.
    */
-  def invalidateCachedPath(
-      sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
+  def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
     val (fs, qualifiedPath) = {
       val path = new Path(resourcePath)
-      val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
-      (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+      val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+      (fs, fs.makeQualified(path))
     }
 
-    cachedData.filter {
-      case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true
-      case _ => false
-    }.foreach { data =>
-      val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
-      if (dataIndex >= 0) {
-        data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
-        cachedData.remove(dataIndex)
-      }
-      sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
-    }
+    recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
   }
 
   /**
```
```diff
@@ -85,12 +85,6 @@ case class InMemoryRelation(
     buildBuffers()
   }
 
-  def recache(): Unit = {
-    _cachedColumnBuffers.unpersist()
-    _cachedColumnBuffers = null
-    buildBuffers()
-  }
-
   private def buildBuffers(): Unit = {
     val output = child.output
     val cached = child.execute().mapPartitionsInternal { rowIterator =>
```
```diff
@@ -199,8 +199,7 @@ case class DropTableCommand(
       }
     }
     try {
-      sparkSession.sharedState.cacheManager.uncacheQuery(
-        sparkSession.table(tableName.quotedString))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
     } catch {
       case _: NoSuchTableException if ifExists =>
       case NonFatal(e) => log.warn(e.toString, e)
```
```diff
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
     val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
     relation.insert(df, overwrite)
 
-    // Invalidate the cache.
-    sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
+    // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
+    // data source relation.
+    sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
 
     Seq.empty[Row]
   }
```
```diff
@@ -343,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def dropTempView(viewName: String): Boolean = {
-    sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
+    sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropTempView(viewName)
     }
   }
@@ -359,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   override def dropGlobalTempView(viewName: String): Boolean = {
     sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropGlobalTempView(viewName)
     }
   }
@@ -404,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def uncacheTable(tableName: String): Unit = {
-    sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
+    sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
   }
 
   /**
@@ -442,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
 
     // If this table is cached as an InMemoryRelation, drop the original
     // cached version and make the new version cached lazily.
-    val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed
-    // Use lookupCachedData directly since RefreshTable also takes databaseName.
-    val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
-    if (isCached) {
-      // Create a data frame to represent the table.
-      // TODO: Use uncacheTable once it supports database name.
-      val df = Dataset.ofRows(sparkSession, logicalPlan)
+    val table = sparkSession.table(tableIdent)
+    if (isCached(table)) {
       // Uncache the logicalPlan.
-      sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
+      sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
       // Cache it again.
-      sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
+      sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
     }
   }
 
@@ -464,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def refreshByPath(resourcePath: String): Unit = {
-    sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
+    sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
   }
 }
```
```diff
@@ -24,15 +24,15 @@ import scala.language.postfixOps
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.CleanerListener
-import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
-import org.apache.spark.util.AccumulatorContext
+import org.apache.spark.util.{AccumulatorContext, Utils}
 
 private case class BigData(s: String)
@@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     maybeBlock.nonEmpty
   }
 
-  private def getNumInMemoryRelations(plan: LogicalPlan): Int = {
+  private def getNumInMemoryRelations(ds: Dataset[_]): Int = {
+    val plan = ds.queryExecution.withCachedData
     var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
     plan.transformAllExpressions {
       case e: SubqueryExpression =>
@@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     assertCached(spark.table("testData"))
 
     assertResult(1, "InMemoryRelation not found, testData should have been cached") {
-      getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData)
+      getNumInMemoryRelations(spark.table("testData"))
     }
 
     spark.catalog.cacheTable("testData")
@@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     localRelation.createOrReplaceTempView("localRelation")
 
     spark.catalog.cacheTable("localRelation")
-    assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1)
+    assert(getNumInMemoryRelations(localRelation) == 1)
   }
 
   test("SPARK-19093 Caching in side subquery") {
     withTempView("t1") {
       Seq(1).toDF("c1").createOrReplaceTempView("t1")
       spark.catalog.cacheTable("t1")
-      val cachedPlan =
+      val ds =
         sql(
           """
             |SELECT * FROM t1
             |WHERE
             |NOT EXISTS (SELECT * FROM t1)
-          """.stripMargin).queryExecution.optimizedPlan
-      assert(getNumInMemoryRelations(cachedPlan) == 2)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 2)
     }
   }
@@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       spark.catalog.cacheTable("t4")
 
       // Nested predicate subquery
-      val cachedPlan =
+      val ds =
        sql(
          """
            |SELECT * FROM t1
            |WHERE
            |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
-          """.stripMargin).queryExecution.optimizedPlan
-      assert(getNumInMemoryRelations(cachedPlan) == 3)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 3)
 
      // Scalar subquery and predicate subquery
-      val cachedPlan2 =
+      val ds2 =
        sql(
          """
            |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
@@ -630,8 +631,27 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
            |EXISTS (SELECT c1 FROM t3)
            |OR
            |c1 IN (SELECT c1 FROM t4)
-          """.stripMargin).queryExecution.optimizedPlan
-      assert(getNumInMemoryRelations(cachedPlan2) == 4)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds2) == 4)
     }
   }
 
+  test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") {
+    withTable("t") {
+      withTempPath { path =>
+        Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
+        sql(s"CREATE TABLE t USING parquet LOCATION '$path'")
+        spark.catalog.cacheTable("t")
+        spark.table("t").select($"i").cache()
+        checkAnswer(spark.table("t").select($"i"), Row(1))
+        assertCached(spark.table("t").select($"i"))
+
+        Utils.deleteRecursively(path)
+        spark.sessionState.catalog.refreshTable(TableIdentifier("t"))
+        spark.catalog.uncacheTable("t")
+        assert(spark.table("t").select($"i").count() == 0)
+        assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0)
+      }
+    }
+  }
```
```diff
@@ -393,8 +393,8 @@ case class InsertIntoHiveTable(
       logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
     }
 
-    // Invalidate the cache.
-    sparkSession.catalog.uncacheTable(table.qualifiedName)
+    // un-cache this table.
+    sparkSession.catalog.uncacheTable(table.identifier.quotedString)
     sparkSession.sessionState.catalog.refreshTable(table.identifier)
 
     // It would be nice to just return the childRdd unchanged so insert operations could be chained,
```
```diff
@@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
     tempPath.delete()
     table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString)
     sql("DROP TABLE IF EXISTS refreshTable")
-    sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet")
-    checkAnswer(
-      table("refreshTable"),
-      table("src").collect())
+    sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet")
+    checkAnswer(table("refreshTable"), table("src"))
     // Cache the table.
     sql("CACHE TABLE refreshTable")
     assertCached(table("refreshTable"))
```
```diff
@@ -453,7 +453,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
     // Converted test_parquet should be cached.
     sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match {
       case null => fail("Converted test_parquet should be cached in the cache.")
-      case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK
+      case LogicalRelation(_: HadoopFsRelation, _, _) => // OK
       case other =>
         fail(
           "The cached test_parquet should be a Parquet Relation. " +
```