[SPARK-33567][SQL] DSv2: Use callback instead of passing Spark session and v2 relation for refreshing cache

### What changes were proposed in this pull request?

This removes the Spark session and `DataSourceV2Relation` from V2 write plans, replacing them with a callback `afterWrite`.

### Why are the changes needed?

Per the discussion in #30429, it is better not to pass the Spark session and `DataSourceV2Relation` through Spark plans. Instead, we can use a callback, which makes the interface cleaner; a simplified sketch of the pattern is shown below.
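
A minimal, self-contained sketch of the callback pattern, using hypothetical stand-in types (`CacheManager`, `Relation`, `AppendExec`) rather than Spark's actual classes. The idea mirrors the `refreshCache`/`invalidateCache` helpers added to `DataSourceV2Strategy` in this patch: the strategy curries the relation into a `() => Unit`, so the physical node never needs the session or the relation itself.

```scala
// Hypothetical stand-ins for illustration only; not Spark's real API.
case class Relation(name: String)

class CacheManager {
  def recacheByPlan(r: Relation): Unit = println(s"re-caching ${r.name}")
}

// The physical node only carries a () => Unit callback.
case class AppendExec(refreshCache: () => Unit) {
  def run(): Unit = {
    // ... perform the actual write here ...
    refreshCache() // invoked once the write has succeeded
  }
}

class Strategy(cache: CacheManager) {
  // The extra empty parameter list lets `refreshCache(r)` eta-expand into a
  // `() => Unit` that captures `r`, keeping cache/session details in the strategy.
  private def refreshCache(r: Relation)(): Unit = cache.recacheByPlan(r)

  def plan(r: Relation): AppendExec = AppendExec(refreshCache(r))
}

object Demo extends App {
  new Strategy(new CacheManager).plan(Relation("t1")).run()
}
```

This is also why the planner-side helpers in the diff are written as `refreshCache(r: DataSourceV2Relation)(): Unit` and `invalidateCache(r: ResolvedTable)(): Unit`: partially applying them yields exactly the `() => Unit` that the exec nodes now take.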

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

N/A

Closes #30491 from sunchao/SPARK-33492-followup.

Authored-by: Chao Sun <sunchao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit feda7299e3 (parent a5e13acd19), committed 2020-11-30 04:50:50 +00:00.
5 changed files with 43 additions and 41 deletions.

`DataSourceV2Strategy.scala`

```diff
@@ -52,6 +52,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper
     }
   }
 
+  private def refreshCache(r: DataSourceV2Relation)(): Unit = {
+    session.sharedState.cacheManager.recacheByPlan(session, r)
+  }
+
+  private def invalidateCache(r: ResolvedTable)(): Unit = {
+    val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier))
+    session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
+  }
+
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(project, filters,
         relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
@@ -128,7 +137,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper
       }
 
     case RefreshTable(r: ResolvedTable) =>
-      RefreshTableExec(session, r.catalog, r.table, r.identifier) :: Nil
+      RefreshTableExec(r.catalog, r.identifier, invalidateCache(r)) :: Nil
 
     case ReplaceTable(catalog, ident, schema, parts, props, orCreate) =>
       val propsWithOwner = CatalogV2Util.withDefaultOwnership(props)
@@ -172,9 +181,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper
     case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
       r.table.asWritable match {
         case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
-          AppendDataExecV1(v1, writeOptions.asOptions, query, r) :: Nil
+          AppendDataExecV1(v1, writeOptions.asOptions, query, refreshCache(r)) :: Nil
         case v2 =>
-          AppendDataExec(session, v2, r, writeOptions.asOptions, planLater(query)) :: Nil
+          AppendDataExec(v2, writeOptions.asOptions, planLater(query), refreshCache(r)) :: Nil
       }
 
     case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
@@ -186,15 +195,16 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper
       }.toArray
       r.table.asWritable match {
         case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
-          OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query, r) :: Nil
+          OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions,
+            query, refreshCache(r)) :: Nil
         case v2 =>
-          OverwriteByExpressionExec(session, v2, r, filters,
-            writeOptions.asOptions, planLater(query)) :: Nil
+          OverwriteByExpressionExec(v2, filters,
+            writeOptions.asOptions, planLater(query), refreshCache(r)) :: Nil
       }
 
     case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) =>
       OverwritePartitionsDynamicExec(
-        session, r.table.asWritable, r, writeOptions.asOptions, planLater(query)) :: Nil
+        r.table.asWritable, writeOptions.asOptions, planLater(query), refreshCache(r)) :: Nil
 
     case DeleteFromTable(relation, condition) =>
       relation match {
@@ -232,7 +242,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper
       throw new AnalysisException("Describing columns is not supported for v2 tables.")
 
     case DropTable(r: ResolvedTable, ifExists, purge) =>
-      DropTableExec(session, r.catalog, r.table, r.identifier, ifExists, purge) :: Nil
+      DropTableExec(r.catalog, r.identifier, ifExists, purge, invalidateCache(r)) :: Nil
 
     case _: NoopDropTable =>
       LocalTableScanExec(Nil, Nil) :: Nil
```

`DropTableExec.scala`

```diff
@@ -17,27 +17,24 @@
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 
 /**
  * Physical plan node for dropping a table.
  */
 case class DropTableExec(
-    session: SparkSession,
     catalog: TableCatalog,
-    table: Table,
     ident: Identifier,
     ifExists: Boolean,
-    purge: Boolean) extends V2CommandExec {
+    purge: Boolean,
+    invalidateCache: () => Unit) extends V2CommandExec {
 
   override def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
-      val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
-      session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
+      invalidateCache()
       catalog.dropTable(ident, purge)
     } else if (!ifExists) {
       throw new NoSuchTableException(ident)
```

`RefreshTableExec.scala`

```diff
@@ -17,23 +17,20 @@
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 
 case class RefreshTableExec(
-    session: SparkSession,
     catalog: TableCatalog,
-    table: Table,
-    ident: Identifier) extends V2CommandExec {
+    ident: Identifier,
+    invalidateCache: () => Unit) extends V2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
     catalog.invalidateTable(ident)
 
     // invalidate all caches referencing the given table
     // TODO(SPARK-33437): re-cache the table itself once we support caching a DSv2 table
-    val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
-    session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
+    invalidateCache()
 
     Seq.empty
   }
```

`V1FallbackWriters.scala`

```diff
@@ -38,10 +38,10 @@ case class AppendDataExecV1(
     table: SupportsWrite,
     writeOptions: CaseInsensitiveStringMap,
     plan: LogicalPlan,
-    v2Relation: DataSourceV2Relation) extends V1FallbackWriters {
+    refreshCache: () => Unit) extends V1FallbackWriters {
 
   override protected def run(): Seq[InternalRow] = {
-    writeWithV1(newWriteBuilder().buildForV1Write(), Some(v2Relation))
+    writeWithV1(newWriteBuilder().buildForV1Write(), refreshCache = refreshCache)
   }
 }
@@ -61,7 +61,7 @@ case class OverwriteByExpressionExecV1(
     deleteWhere: Array[Filter],
     writeOptions: CaseInsensitiveStringMap,
     plan: LogicalPlan,
-    v2Relation: DataSourceV2Relation) extends V1FallbackWriters {
+    refreshCache: () => Unit) extends V1FallbackWriters {
 
   private def isTruncate(filters: Array[Filter]): Boolean = {
     filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
@@ -70,10 +70,11 @@ case class OverwriteByExpressionExecV1(
   override protected def run(): Seq[InternalRow] = {
     newWriteBuilder() match {
       case builder: SupportsTruncate if isTruncate(deleteWhere) =>
-        writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), Some(v2Relation))
+        writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), refreshCache = refreshCache)
 
       case builder: SupportsOverwrite =>
-        writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(), Some(v2Relation))
+        writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(),
+          refreshCache = refreshCache)
 
       case _ =>
         throw new SparkException(s"Table does not support overwrite by expression: $table")
@@ -116,11 +117,11 @@ trait SupportsV1Write extends SparkPlan {
   protected def writeWithV1(
       relation: InsertableRelation,
-      v2Relation: Option[DataSourceV2Relation] = None): Seq[InternalRow] = {
+      refreshCache: () => Unit = () => ()): Seq[InternalRow] = {
     val session = sqlContext.sparkSession
     // The `plan` is already optimized, we should not analyze and optimize it again.
     relation.insert(AlreadyOptimized.dataFrame(session, plan), overwrite = false)
-    v2Relation.foreach(r => session.sharedState.cacheManager.recacheByPlan(session, r))
+    refreshCache()
 
     Nil
   }
```

`WriteToDataSourceV2Exec.scala`

```diff
@@ -213,15 +213,14 @@ case class AtomicReplaceTableAsSelectExec(
  * Rows in the output data set are appended.
  */
 case class AppendDataExec(
-    session: SparkSession,
     table: SupportsWrite,
-    relation: DataSourceV2Relation,
     writeOptions: CaseInsensitiveStringMap,
-    query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+    query: SparkPlan,
+    refreshCache: () => Unit) extends V2TableWriteExec with BatchWriteHelper {
 
   override protected def run(): Seq[InternalRow] = {
     val writtenRows = writeWithV2(newWriteBuilder().buildForBatch())
-    session.sharedState.cacheManager.recacheByPlan(session, relation)
+    refreshCache()
     writtenRows
   }
 }
@@ -237,12 +236,11 @@ case class AppendDataExec(
  * AlwaysTrue to delete all rows.
  */
 case class OverwriteByExpressionExec(
-    session: SparkSession,
     table: SupportsWrite,
-    relation: DataSourceV2Relation,
     deleteWhere: Array[Filter],
     writeOptions: CaseInsensitiveStringMap,
-    query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+    query: SparkPlan,
+    refreshCache: () => Unit) extends V2TableWriteExec with BatchWriteHelper {
 
   private def isTruncate(filters: Array[Filter]): Boolean = {
     filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
@@ -259,7 +257,7 @@ case class OverwriteByExpressionExec(
       case _ =>
         throw new SparkException(s"Table does not support overwrite by expression: $table")
     }
-    session.sharedState.cacheManager.recacheByPlan(session, relation)
+    refreshCache()
     writtenRows
   }
 }
@@ -275,11 +273,10 @@ case class OverwriteByExpressionExec(
  * are not modified.
  */
 case class OverwritePartitionsDynamicExec(
-    session: SparkSession,
     table: SupportsWrite,
-    relation: DataSourceV2Relation,
     writeOptions: CaseInsensitiveStringMap,
-    query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+    query: SparkPlan,
+    refreshCache: () => Unit) extends V2TableWriteExec with BatchWriteHelper {
 
   override protected def run(): Seq[InternalRow] = {
     val writtenRows = newWriteBuilder() match {
@@ -289,7 +286,7 @@ case class OverwritePartitionsDynamicExec(
       case _ =>
         throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
     }
-    session.sharedState.cacheManager.recacheByPlan(session, relation)
+    refreshCache()
     writtenRows
   }
 }
```