[SPARK-33817][SQL] CACHE TABLE uses a logical plan when caching a query to avoid creating a dataframe

### What changes were proposed in this pull request?

This PR proposes to update `CACHE TABLE` to use a `LogicalPlan` when caching a query to avoid creating a `DataFrame` as suggested here: https://github.com/apache/spark/pull/30743#discussion_r543123190

For reference, `UNCACHE TABLE` also uses `LogicalPlan`: 0c12900120/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala (L91-L98)
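
Concretely, the V2 cache commands now hand the analyzed `LogicalPlan` straight to `CacheManager.cacheQuery` instead of first wrapping it in a `Dataset`. A minimal sketch of the call-site difference, assuming code that lives inside the `org.apache.spark.sql` package (so the package-private `Dataset.ofRows`, `sharedState`, and `cacheManager` are reachable); the object and helper names are hypothetical:

```scala
package org.apache.spark.sql.execution

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object CacheCallSiteSketch {
  // Old path: build a DataFrame only so cacheQuery can read back its logicalPlan.
  def cacheViaDataFrame(spark: SparkSession, plan: LogicalPlan, name: String): Unit = {
    val df = Dataset.ofRows(spark, plan)
    spark.sharedState.cacheManager.cacheQuery(df, Some(name))
  }

  // New path (this PR): pass the session and the analyzed plan directly.
  def cacheViaPlan(spark: SparkSession, plan: LogicalPlan, name: String): Unit = {
    spark.sharedState.cacheManager.cacheQuery(spark, plan, Some(name))
  }
}
```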

### Why are the changes needed?

To avoid creating an unnecessary `DataFrame` and to make the `CACHE TABLE` code path consistent with the `uncacheQuery` call used by `UNCACHE TABLE`.
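
For reference, a rough sketch of the two plan-based entry points after this change; `cacheQuery`'s parameters come from the diff below, while `uncacheQuery`'s exact parameter list is an assumption based on the existing `UNCACHE TABLE` code path linked above, and the trait name is hypothetical:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.storage.StorageLevel

// Hypothetical trait collecting the plan-based CacheManager entry points.
trait PlanBasedCacheCalls {
  // Added by this PR: cache the data produced by an analyzed plan.
  def cacheQuery(
      spark: SparkSession,
      planToCache: LogicalPlan,
      tableName: Option[String],
      storageLevel: StorageLevel): Unit

  // Existing uncache path used by UNCACHE TABLE (parameter list assumed).
  def uncacheQuery(
      spark: SparkSession,
      plan: LogicalPlan,
      cascade: Boolean,
      blocking: Boolean = false): Unit
}
```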

### Does this PR introduce _any_ user-facing change?

No, just internal changes.

### How was this patch tested?

Existing tests since this is an internal refactoring change.

Closes #30815 from imback82/cache_with_logical_plan.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit 0f1a18370a (parent 131a23d88a)
Terry Kim, 2020-12-18 04:30:15 +00:00, committed by Wenchen Fan
4 changed files with 57 additions and 20 deletions

Analyzer.scala

@@ -1107,12 +1107,16 @@ class Analyzer(override val catalogManager: CatalogManager)
       case c @ CacheTable(u @ UnresolvedRelation(_, _, false), _, _, _) =>
         lookupRelation(u.multipartIdentifier, u.options, false)
-          .map(relation => c.copy(table = EliminateSubqueryAliases(relation)))
+          .map(resolveViews)
+          .map(EliminateSubqueryAliases(_))
+          .map(relation => c.copy(table = relation))
           .getOrElse(c)

       case c @ UncacheTable(u @ UnresolvedRelation(_, _, false), _, _) =>
         lookupRelation(u.multipartIdentifier, u.options, false)
-          .map(relation => c.copy(table = EliminateSubqueryAliases(relation)))
+          .map(resolveViews)
+          .map(EliminateSubqueryAliases(_))
+          .map(relation => c.copy(table = relation))
           .getOrElse(c)

       // TODO (SPARK-27484): handle streaming write commands when we have them.

CacheManager.scala

@@ -88,12 +88,34 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper
       query: Dataset[_],
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = {
-    val planToCache = query.logicalPlan
+    cacheQuery(query.sparkSession, query.logicalPlan, tableName, storageLevel)
+  }
+
+  /**
+   * Caches the data produced by the given [[LogicalPlan]].
+   * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
+   * recomputing the in-memory columnar representation of the underlying table is expensive.
+   */
+  def cacheQuery(
+      spark: SparkSession,
+      planToCache: LogicalPlan,
+      tableName: Option[String]): Unit = {
+    cacheQuery(spark, planToCache, tableName, MEMORY_AND_DISK)
+  }
+
+  /**
+   * Caches the data produced by the given [[LogicalPlan]].
+   */
+  def cacheQuery(
+      spark: SparkSession,
+      planToCache: LogicalPlan,
+      tableName: Option[String],
+      storageLevel: StorageLevel): Unit = {
     if (lookupCachedData(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
       val sessionWithConfigsOff = SparkSession.getOrCloneSessionWithConfigsOff(
-        query.sparkSession, forceDisableConfigs)
+        spark, forceDisableConfigs)
       val inMemoryRelation = sessionWithConfigsOff.withActive {
         val qe = sessionWithConfigsOff.sessionState.executePlan(planToCache)
         InMemoryRelation(

CacheTableExec.scala

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

 import java.util.Locale

-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

@@ -29,10 +29,13 @@ import org.apache.spark.storage.StorageLevel
 trait BaseCacheTableExec extends V2CommandExec {
   def relationName: String
-  def dataFrameToCache: DataFrame
+  def planToCache: LogicalPlan
+  def dataFrameForCachedPlan: DataFrame
   def isLazy: Boolean
   def options: Map[String, String]

+  protected val sparkSession: SparkSession = sqlContext.sparkSession
+
   override def run(): Seq[InternalRow] = {
     val storageLevelKey = "storagelevel"
     val storageLevelValue =

@@ -42,20 +45,22 @@ trait BaseCacheTableExec extends V2CommandExec {
       logWarning(s"Invalid options: ${withoutStorageLevel.mkString(", ")}")
     }

-    val sparkSession = sqlContext.sparkSession
-    val df = dataFrameToCache
     if (storageLevelValue.nonEmpty) {
       sparkSession.sharedState.cacheManager.cacheQuery(
-        df,
+        sparkSession,
+        planToCache,
         Some(relationName),
         StorageLevel.fromString(storageLevelValue.get))
     } else {
-      sparkSession.sharedState.cacheManager.cacheQuery(df, Some(relationName))
+      sparkSession.sharedState.cacheManager.cacheQuery(
+        sparkSession,
+        planToCache,
+        Some(relationName))
     }

     if (!isLazy) {
-      // Performs eager caching
-      df.count()
+      // Performs eager caching.
+      dataFrameForCachedPlan.count()
     }

     Seq.empty

@@ -69,9 +74,13 @@ case class CacheTableExec(
     multipartIdentifier: Seq[String],
     override val isLazy: Boolean,
     override val options: Map[String, String]) extends BaseCacheTableExec {
-  override def relationName: String = multipartIdentifier.quoted
-  override def dataFrameToCache: DataFrame = Dataset.ofRows(sqlContext.sparkSession, relation)
+  override lazy val relationName: String = multipartIdentifier.quoted
+  override lazy val planToCache: LogicalPlan = relation
+
+  override lazy val dataFrameForCachedPlan: DataFrame = {
+    Dataset.ofRows(sparkSession, planToCache)
+  }
 }

 case class CacheTableAsSelectExec(

@@ -79,11 +88,14 @@ case class CacheTableAsSelectExec(
     query: LogicalPlan,
     override val isLazy: Boolean,
     override val options: Map[String, String]) extends BaseCacheTableExec {
-  override def relationName: String = tempViewName
-  override def dataFrameToCache: DataFrame = {
-    val sparkSession = sqlContext.sparkSession
+  override lazy val relationName: String = tempViewName
+
+  override lazy val planToCache: LogicalPlan = {
     Dataset.ofRows(sparkSession, query).createTempView(tempViewName)
+    dataFrameForCachedPlan.logicalPlan
+  }
+
+  override lazy val dataFrameForCachedPlan: DataFrame = {
     sparkSession.table(tempViewName)
   }
 }

DataSourceV2Strategy.scala

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

 import scala.collection.JavaConverters._

-import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession, Strategy}
+import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation

@@ -66,8 +66,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel

         // recache with the same name and cache level.
-        val ds = Dataset.ofRows(session, v2Relation)
-        session.sharedState.cacheManager.cacheQuery(ds, cacheName, cacheLevel)
+        session.sharedState.cacheManager.cacheQuery(session, v2Relation, cacheName, cacheLevel)
       }
     }