diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 526623a36d..4ca1347008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -131,12 +132,16 @@ class CacheManager extends Logging {
 
   /** Replaces segments of the given logical plan with cached versions where possible. */
   def useCachedData(plan: LogicalPlan): LogicalPlan = {
-    plan transformDown {
+    val newPlan = plan transformDown {
       case currentFragment =>
         lookupCachedData(currentFragment)
           .map(_.cachedRepresentation.withOutput(currentFragment.output))
           .getOrElse(currentFragment)
     }
+
+    newPlan transformAllExpressions {
+      case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan))
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index f42402e1cc..fb4812adf1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -24,6 +24,8 @@ import scala.language.postfixOps
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.CleanerListener
+import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -37,6 +39,14 @@ private case class BigData(s: String)
 class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext {
   import testImplicits._
 
+  override def afterEach(): Unit = {
+    try {
+      spark.catalog.clearCache()
+    } finally {
+      super.afterEach()
+    }
+  }
+
   def rddIdOf(tableName: String): Int = {
     val plan = spark.table(tableName).queryExecution.sparkPlan
     plan.collect {
@@ -53,6 +63,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     maybeBlock.nonEmpty
   }
 
+  private def getNumInMemoryRelations(plan: LogicalPlan): Int = {
+    var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
+    plan.transformAllExpressions {
+      case e: SubqueryExpression =>
+        sum += getNumInMemoryRelations(e.plan)
+        e
+    }
+    sum
+  }
+
   test("withColumn doesn't invalidate cached dataframe") {
     var evalCount = 0
     val myUDF = udf((x: String) => { evalCount += 1; "result" })
@@ -565,4 +585,56 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       case i: InMemoryRelation => i
     }.size == 1)
   }
+
+  test("SPARK-19093 Caching in side subquery") {
+    withTempView("t1") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      spark.catalog.cacheTable("t1")
+      val cachedPlan =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t1)
+          """.stripMargin).queryExecution.optimizedPlan
+      assert(getNumInMemoryRelations(cachedPlan) == 2)
+    }
+  }
+
+  test("SPARK-19093 scalar and nested predicate query") {
+    withTempView("t1", "t2", "t3", "t4") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+      Seq(1).toDF("c1").createOrReplaceTempView("t3")
+      Seq(1).toDF("c1").createOrReplaceTempView("t4")
+      spark.catalog.cacheTable("t1")
+      spark.catalog.cacheTable("t2")
+      spark.catalog.cacheTable("t3")
+      spark.catalog.cacheTable("t4")
+
+      // Nested predicate subquery
+      val cachedPlan =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+          """.stripMargin).queryExecution.optimizedPlan
+      assert(getNumInMemoryRelations(cachedPlan) == 3)
+
+      // Scalar subquery and predicate subquery
+      val cachedPlan2 =
+        sql(
+          """
+            |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+            |WHERE
+            |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+            |OR
+            |EXISTS (SELECT c1 FROM t3)
+            |OR
+            |c1 IN (SELECT c1 FROM t4)
+          """.stripMargin).queryExecution.optimizedPlan
+      assert(getNumInMemoryRelations(cachedPlan2) == 4)
+    }
+  }
 }