[SPARK-3212][SQL] Use logical plan matching instead of temporary tables for table caching
_Also addresses: SPARK-1671, SPARK-1379 and SPARK-3641_
This PR introduces a new trait, `CacheManger`, which replaces the previous temporary table based caching system. Instead of creating a temporary table that shadows an existing table with and equivalent cached representation, the cached manager maintains a separate list of logical plans and their cached data. After optimization, this list is searched for any matching plan fragments. When a matching plan fragment is found it is replaced with the cached data.
There are several advantages to this approach:
- Calling .cache() on a SchemaRDD now works as you would expect, and uses the more efficient columnar representation.
- Its now possible to provide a list of temporary tables, without having to decide if a given table is actually just a cached persistent table. (To be done in a follow-up PR)
- In some cases it is possible that cached data will be used, even if a cached table was not explicitly requested. This is because we now look at the logical structure instead of the table name.
- We now correctly invalidate when data is inserted into a hive table.
Author: Michael Armbrust <michael@databricks.com>
Closes #2501 from marmbrus/caching and squashes the following commits:
63fbc2c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching.
0ea889e [Michael Armbrust] Address comments.
1e23287 [Michael Armbrust] Add support for cache invalidation for hive inserts.
65ed04a [Michael Armbrust] fix tests.
bdf9a3f
[Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching
b4b77f2 [Michael Armbrust] Address comments
6923c9d [Michael Armbrust] More comments / tests
80f26ac [Michael Armbrust] First draft of improved semantics for Spark SQL caching.
This commit is contained in:
parent
bec0d0eaa3
commit
6a1d48f4f0
|
@ -93,6 +93,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
|
|||
*/
|
||||
object ResolveRelations extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) =>
|
||||
i.copy(
|
||||
table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias)))
|
||||
case UnresolvedRelation(databaseName, name, alias) =>
|
||||
catalog.lookupRelation(databaseName, name, alias)
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression {
|
|||
def withName(newName: String): Attribute
|
||||
|
||||
def toAttribute = this
|
||||
def newInstance: Attribute
|
||||
def newInstance(): Attribute
|
||||
|
||||
}
|
||||
|
||||
|
@ -131,7 +131,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
|
|||
h
|
||||
}
|
||||
|
||||
override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
|
||||
override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
|
||||
|
||||
/**
|
||||
* Returns a copy of this [[AttributeReference]] with changed nullability.
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
|
|||
import org.apache.spark.sql.catalyst.errors.TreeNodeException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||
import org.apache.spark.sql.catalyst.types.StructType
|
||||
import org.apache.spark.sql.catalyst.trees
|
||||
|
||||
|
@ -72,6 +73,47 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
|
|||
*/
|
||||
def childrenResolved: Boolean = !children.exists(!_.resolved)
|
||||
|
||||
/**
|
||||
* Returns true when the given logical plan will return the same results as this logical plan.
|
||||
*
|
||||
* Since its likely undecideable to generally determine if two given plans will produce the same
|
||||
* results, it is okay for this function to return false, even if the results are actually
|
||||
* the same. Such behavior will not affect correctness, only the application of performance
|
||||
* enhancements like caching. However, it is not acceptable to return true if the results could
|
||||
* possibly be different.
|
||||
*
|
||||
* By default this function performs a modified version of equality that is tolerant of cosmetic
|
||||
* differences like attribute naming and or expression id differences. Logical operators that
|
||||
* can do better should override this function.
|
||||
*/
|
||||
def sameResult(plan: LogicalPlan): Boolean = {
|
||||
plan.getClass == this.getClass &&
|
||||
plan.children.size == children.size && {
|
||||
logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]")
|
||||
cleanArgs == plan.cleanArgs
|
||||
} &&
|
||||
(plan.children, children).zipped.forall(_ sameResult _)
|
||||
}
|
||||
|
||||
/** Args that have cleaned such that differences in expression id should not affect equality */
|
||||
protected lazy val cleanArgs: Seq[Any] = {
|
||||
val input = children.flatMap(_.output)
|
||||
productIterator.map {
|
||||
// Children are checked using sameResult above.
|
||||
case tn: TreeNode[_] if children contains tn => null
|
||||
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
|
||||
case s: Option[_] => s.map {
|
||||
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
|
||||
case other => other
|
||||
}
|
||||
case s: Seq[_] => s.map {
|
||||
case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
|
||||
case other => other
|
||||
}
|
||||
case other => other
|
||||
}.toSeq
|
||||
}
|
||||
|
||||
/**
|
||||
* Optionally resolves the given string to a [[NamedExpression]] using the input from all child
|
||||
* nodes of this LogicalPlan. The attribute is expressed as
|
||||
|
|
|
@ -41,4 +41,10 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
|
|||
}
|
||||
|
||||
override protected def stringArgs = Iterator(output)
|
||||
|
||||
override def sameResult(plan: LogicalPlan): Boolean = plan match {
|
||||
case LocalRelation(otherOutput, otherData) =>
|
||||
otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -105,8 +105,8 @@ case class InsertIntoTable(
|
|||
child: LogicalPlan,
|
||||
overwrite: Boolean)
|
||||
extends LogicalPlan {
|
||||
// The table being inserted into is a child for the purposes of transformations.
|
||||
override def children = table :: child :: Nil
|
||||
|
||||
override def children = child :: Nil
|
||||
override def output = child.output
|
||||
|
||||
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.plans
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
|
||||
/**
|
||||
* Provides helper methods for comparing plans.
|
||||
*/
|
||||
class SameResultSuite extends FunSuite {
|
||||
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
|
||||
def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = {
|
||||
val aAnalyzed = a.analyze
|
||||
val bAnalyzed = b.analyze
|
||||
|
||||
if (aAnalyzed.sameResult(bAnalyzed) != result) {
|
||||
val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
|
||||
fail(s"Plans should return sameResult = $result\n$comparison")
|
||||
}
|
||||
}
|
||||
|
||||
test("relations") {
|
||||
assertSameResult(testRelation, testRelation2)
|
||||
}
|
||||
|
||||
test("projections") {
|
||||
assertSameResult(testRelation.select('a), testRelation2.select('a))
|
||||
assertSameResult(testRelation.select('b), testRelation2.select('b))
|
||||
assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
|
||||
assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))
|
||||
|
||||
assertSameResult(testRelation, testRelation2.select('a), false)
|
||||
assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), false)
|
||||
}
|
||||
|
||||
test("filters") {
|
||||
assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))
|
||||
}
|
||||
}
|
139
sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
Normal file
139
sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
Normal file
|
@ -0,0 +1,139 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock
|
||||
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.columnar.InMemoryRelation
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
|
||||
|
||||
/** Holds a cached logical plan and its data */
|
||||
private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
|
||||
|
||||
/**
|
||||
* Provides support in a SQLContext for caching query results and automatically using these cached
|
||||
* results when subsequent queries are executed. Data is cached using byte buffers stored in an
|
||||
* InMemoryRelation. This relation is automatically substituted query plans that return the
|
||||
* `sameResult` as the originally cached query.
|
||||
*/
|
||||
private[sql] trait CacheManager {
|
||||
self: SQLContext =>
|
||||
|
||||
@transient
|
||||
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
|
||||
|
||||
@transient
|
||||
private val cacheLock = new ReentrantReadWriteLock
|
||||
|
||||
/** Returns true if the table is currently cached in-memory. */
|
||||
def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty
|
||||
|
||||
/** Caches the specified table in-memory. */
|
||||
def cacheTable(tableName: String): Unit = cacheQuery(table(tableName))
|
||||
|
||||
/** Removes the specified table from the in-memory cache. */
|
||||
def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName))
|
||||
|
||||
/** Acquires a read lock on the cache for the duration of `f`. */
|
||||
private def readLock[A](f: => A): A = {
|
||||
val lock = cacheLock.readLock()
|
||||
lock.lock()
|
||||
try f finally {
|
||||
lock.unlock()
|
||||
}
|
||||
}
|
||||
|
||||
/** Acquires a write lock on the cache for the duration of `f`. */
|
||||
private def writeLock[A](f: => A): A = {
|
||||
val lock = cacheLock.writeLock()
|
||||
lock.lock()
|
||||
try f finally {
|
||||
lock.unlock()
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] def clearCache(): Unit = writeLock {
|
||||
cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
|
||||
cachedData.clear()
|
||||
}
|
||||
|
||||
/** Caches the data produced by the logical representation of the given schema rdd. */
|
||||
private[sql] def cacheQuery(
|
||||
query: SchemaRDD,
|
||||
storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock {
|
||||
val planToCache = query.queryExecution.optimizedPlan
|
||||
if (lookupCachedData(planToCache).nonEmpty) {
|
||||
logWarning("Asked to cache already cached data.")
|
||||
} else {
|
||||
cachedData +=
|
||||
CachedData(
|
||||
planToCache,
|
||||
InMemoryRelation(
|
||||
useCompression, columnBatchSize, storageLevel, query.queryExecution.executedPlan))
|
||||
}
|
||||
}
|
||||
|
||||
/** Removes the data for the given SchemaRDD from the cache */
|
||||
private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = false): Unit = writeLock {
|
||||
val planToCache = query.queryExecution.optimizedPlan
|
||||
val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache))
|
||||
|
||||
if (dataIndex < 0) {
|
||||
throw new IllegalArgumentException(s"Table $query is not cached.")
|
||||
}
|
||||
|
||||
cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
|
||||
cachedData.remove(dataIndex)
|
||||
}
|
||||
|
||||
|
||||
/** Optionally returns cached data for the given SchemaRDD */
|
||||
private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
|
||||
lookupCachedData(query.queryExecution.optimizedPlan)
|
||||
}
|
||||
|
||||
/** Optionally returns cached data for the given LogicalPlan. */
|
||||
private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
|
||||
cachedData.find(_.plan.sameResult(plan))
|
||||
}
|
||||
|
||||
/** Replaces segments of the given logical plan with cached versions where possible. */
|
||||
private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = {
|
||||
plan transformDown {
|
||||
case currentFragment =>
|
||||
lookupCachedData(currentFragment)
|
||||
.map(_.cachedRepresentation.withOutput(currentFragment.output))
|
||||
.getOrElse(currentFragment)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invalidates the cache of any data that contains `plan`. Note that it is possible that this
|
||||
* function will over invalidate.
|
||||
*/
|
||||
private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock {
|
||||
cachedData.foreach {
|
||||
case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
|
||||
data.cachedRepresentation.recache()
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -50,6 +50,7 @@ import org.apache.spark.{Logging, SparkContext}
|
|||
class SQLContext(@transient val sparkContext: SparkContext)
|
||||
extends org.apache.spark.Logging
|
||||
with SQLConf
|
||||
with CacheManager
|
||||
with ExpressionConversions
|
||||
with UDFRegistration
|
||||
with Serializable {
|
||||
|
@ -96,7 +97,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
*/
|
||||
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
|
||||
SparkPlan.currentContext.set(self)
|
||||
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self))
|
||||
new SchemaRDD(this,
|
||||
LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -133,7 +135,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = {
|
||||
// TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied
|
||||
// schema differs from the existing schema on any field data type.
|
||||
val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self)
|
||||
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
|
||||
new SchemaRDD(this, logicalPlan)
|
||||
}
|
||||
|
||||
|
@ -272,45 +274,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
def table(tableName: String): SchemaRDD =
|
||||
new SchemaRDD(this, catalog.lookupRelation(None, tableName))
|
||||
|
||||
/** Caches the specified table in-memory. */
|
||||
def cacheTable(tableName: String): Unit = {
|
||||
val currentTable = table(tableName).queryExecution.analyzed
|
||||
val asInMemoryRelation = currentTable match {
|
||||
case _: InMemoryRelation =>
|
||||
currentTable
|
||||
|
||||
case _ =>
|
||||
InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan)
|
||||
}
|
||||
|
||||
catalog.registerTable(None, tableName, asInMemoryRelation)
|
||||
}
|
||||
|
||||
/** Removes the specified table from the in-memory cache. */
|
||||
def uncacheTable(tableName: String): Unit = {
|
||||
table(tableName).queryExecution.analyzed match {
|
||||
// This is kind of a hack to make sure that if this was just an RDD registered as a table,
|
||||
// we reregister the RDD as a table.
|
||||
case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) =>
|
||||
inMem.cachedColumnBuffers.unpersist()
|
||||
catalog.unregisterTable(None, tableName)
|
||||
catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self))
|
||||
case inMem: InMemoryRelation =>
|
||||
inMem.cachedColumnBuffers.unpersist()
|
||||
catalog.unregisterTable(None, tableName)
|
||||
case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan")
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns true if the table is currently cached in-memory. */
|
||||
def isCached(tableName: String): Boolean = {
|
||||
val relation = table(tableName).queryExecution.analyzed
|
||||
relation match {
|
||||
case _: InMemoryRelation => true
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
|
||||
protected[sql] class SparkPlanner extends SparkStrategies {
|
||||
val sparkContext: SparkContext = self.sparkContext
|
||||
|
||||
|
@ -401,10 +364,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
|
||||
lazy val analyzed = ExtractPythonUdfs(analyzer(logical))
|
||||
lazy val optimizedPlan = optimizer(analyzed)
|
||||
lazy val withCachedData = useCachedData(optimizedPlan)
|
||||
|
||||
// TODO: Don't just pick the first one...
|
||||
lazy val sparkPlan = {
|
||||
SparkPlan.currentContext.set(self)
|
||||
planner(optimizedPlan).next()
|
||||
planner(withCachedData).next()
|
||||
}
|
||||
// executedPlan should not be used to initialize any SparkPlan. It should be
|
||||
// only used for execution.
|
||||
|
@ -526,6 +491,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
iter.map { m => new GenericRow(m): Row}
|
||||
}
|
||||
|
||||
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self))
|
||||
new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.spark.sql
|
|||
|
||||
import java.util.{Map => JMap, List => JList}
|
||||
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
|
@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
|
||||
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
|
||||
import org.apache.spark.sql.execution.LogicalRDD
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
|
||||
/**
|
||||
|
@ -442,8 +444,7 @@ class SchemaRDD(
|
|||
*/
|
||||
private def applySchema(rdd: RDD[Row]): SchemaRDD = {
|
||||
new SchemaRDD(sqlContext,
|
||||
SparkLogicalPlan(
|
||||
ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext))
|
||||
LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext))
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
|
@ -497,4 +498,20 @@ class SchemaRDD(
|
|||
override def subtract(other: RDD[Row], p: Partitioner)
|
||||
(implicit ord: Ordering[Row] = null): SchemaRDD =
|
||||
applySchema(super.subtract(other, p)(ord))
|
||||
|
||||
/** Overridden cache function will always use the in-memory columnar caching. */
|
||||
override def cache(): this.type = {
|
||||
sqlContext.cacheQuery(this)
|
||||
this
|
||||
}
|
||||
|
||||
override def persist(newLevel: StorageLevel): this.type = {
|
||||
sqlContext.cacheQuery(this, newLevel)
|
||||
this
|
||||
}
|
||||
|
||||
override def unpersist(blocking: Boolean): this.type = {
|
||||
sqlContext.uncacheQuery(this, blocking)
|
||||
this
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql
|
|||
import org.apache.spark.annotation.{DeveloperApi, Experimental}
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.execution.SparkLogicalPlan
|
||||
import org.apache.spark.sql.execution.LogicalRDD
|
||||
|
||||
/**
|
||||
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
|
||||
|
@ -55,8 +55,7 @@ private[sql] trait SchemaRDDLike {
|
|||
// For various commands (like DDL) and queries with side effects, we force query optimization to
|
||||
// happen right away to let these side effects take place eagerly.
|
||||
case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile =>
|
||||
queryExecution.toRdd
|
||||
SparkLogicalPlan(queryExecution.executedPlan)(sqlContext)
|
||||
LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
|
||||
case _ =>
|
||||
baseLogicalPlan
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.json.JsonRDD
|
|||
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
|
||||
import org.apache.spark.sql.parquet.ParquetRelation
|
||||
import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
|
||||
import org.apache.spark.sql.execution.LogicalRDD
|
||||
import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
@ -100,7 +100,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
|
|||
new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
|
||||
}
|
||||
}
|
||||
new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext))
|
||||
new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -114,7 +114,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
|
|||
val scalaRowRDD = rowRDD.rdd.map(r => r.row)
|
||||
val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType]
|
||||
val logicalPlan =
|
||||
SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext)
|
||||
LogicalRDD(scalaSchema.toAttributes, scalaRowRDD)(sqlContext)
|
||||
new JavaSchemaRDD(sqlContext, logicalPlan)
|
||||
}
|
||||
|
||||
|
@ -151,7 +151,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
|
|||
val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
|
||||
val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
|
||||
val logicalPlan =
|
||||
SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
|
||||
LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
|
||||
new JavaSchemaRDD(sqlContext, logicalPlan)
|
||||
}
|
||||
|
||||
|
@ -167,7 +167,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
|
|||
JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType]
|
||||
val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
|
||||
val logicalPlan =
|
||||
SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
|
||||
LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
|
||||
new JavaSchemaRDD(sqlContext, logicalPlan)
|
||||
}
|
||||
|
||||
|
|
|
@ -27,10 +27,15 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
private[sql] object InMemoryRelation {
|
||||
def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation =
|
||||
new InMemoryRelation(child.output, useCompression, batchSize, child)()
|
||||
def apply(
|
||||
useCompression: Boolean,
|
||||
batchSize: Int,
|
||||
storageLevel: StorageLevel,
|
||||
child: SparkPlan): InMemoryRelation =
|
||||
new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)()
|
||||
}
|
||||
|
||||
private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
|
||||
|
@ -39,6 +44,7 @@ private[sql] case class InMemoryRelation(
|
|||
output: Seq[Attribute],
|
||||
useCompression: Boolean,
|
||||
batchSize: Int,
|
||||
storageLevel: StorageLevel,
|
||||
child: SparkPlan)
|
||||
(private var _cachedColumnBuffers: RDD[CachedBatch] = null)
|
||||
extends LogicalPlan with MultiInstanceRelation {
|
||||
|
@ -51,6 +57,16 @@ private[sql] case class InMemoryRelation(
|
|||
// If the cached column buffers were not passed in, we calculate them in the constructor.
|
||||
// As in Spark, the actual work of caching is lazy.
|
||||
if (_cachedColumnBuffers == null) {
|
||||
buildBuffers()
|
||||
}
|
||||
|
||||
def recache() = {
|
||||
_cachedColumnBuffers.unpersist()
|
||||
_cachedColumnBuffers = null
|
||||
buildBuffers()
|
||||
}
|
||||
|
||||
private def buildBuffers(): Unit = {
|
||||
val output = child.output
|
||||
val cached = child.execute().mapPartitions { rowIterator =>
|
||||
new Iterator[CachedBatch] {
|
||||
|
@ -80,12 +96,17 @@ private[sql] case class InMemoryRelation(
|
|||
|
||||
def hasNext = rowIterator.hasNext
|
||||
}
|
||||
}.cache()
|
||||
}.persist(storageLevel)
|
||||
|
||||
cached.setName(child.toString)
|
||||
_cachedColumnBuffers = cached
|
||||
}
|
||||
|
||||
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
|
||||
InMemoryRelation(
|
||||
newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers)
|
||||
}
|
||||
|
||||
override def children = Seq.empty
|
||||
|
||||
override def newInstance() = {
|
||||
|
@ -93,6 +114,7 @@ private[sql] case class InMemoryRelation(
|
|||
output.map(_.newInstance),
|
||||
useCompression,
|
||||
batchSize,
|
||||
storageLevel,
|
||||
child)(
|
||||
_cachedColumnBuffers).asInstanceOf[this.type]
|
||||
}
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{SQLContext, Row}
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
*/
|
||||
@DeveloperApi
|
||||
object RDDConversions {
|
||||
def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
|
||||
data.mapPartitions { iterator =>
|
||||
if (iterator.isEmpty) {
|
||||
Iterator.empty
|
||||
} else {
|
||||
val bufferedIterator = iterator.buffered
|
||||
val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
|
||||
|
||||
bufferedIterator.map { r =>
|
||||
var i = 0
|
||||
while (i < mutableRow.length) {
|
||||
mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
|
||||
i += 1
|
||||
}
|
||||
|
||||
mutableRow
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = {
|
||||
LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
|
||||
extends LogicalPlan with MultiInstanceRelation {
|
||||
|
||||
def children = Nil
|
||||
|
||||
def newInstance() =
|
||||
LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
|
||||
|
||||
override def sameResult(plan: LogicalPlan) = plan match {
|
||||
case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
|
||||
case _ => false
|
||||
}
|
||||
|
||||
@transient override lazy val statistics = Statistics(
|
||||
// TODO: Instead of returning a default value here, find a way to return a meaningful size
|
||||
// estimate for RDDs. See PR 1238 for more discussions.
|
||||
sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
|
||||
)
|
||||
}
|
||||
|
||||
case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
|
||||
override def execute() = rdd
|
||||
}
|
||||
|
||||
@deprecated("Use LogicalRDD", "1.2.0")
|
||||
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
|
||||
override def execute() = rdd
|
||||
}
|
||||
|
||||
@deprecated("Use LogicalRDD", "1.2.0")
|
||||
case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
|
||||
extends LogicalPlan with MultiInstanceRelation {
|
||||
|
||||
def output = alreadyPlanned.output
|
||||
override def children = Nil
|
||||
|
||||
override final def newInstance(): this.type = {
|
||||
SparkLogicalPlan(
|
||||
alreadyPlanned match {
|
||||
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
|
||||
case _ => sys.error("Multiple instance of the same relation detected.")
|
||||
})(sqlContext).asInstanceOf[this.type]
|
||||
}
|
||||
|
||||
override def sameResult(plan: LogicalPlan) = plan match {
|
||||
case SparkLogicalPlan(ExistingRdd(_, rdd)) =>
|
||||
rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id
|
||||
case _ => false
|
||||
}
|
||||
|
||||
@transient override lazy val statistics = Statistics(
|
||||
// TODO: Instead of returning a default value here, find a way to return a meaningful size
|
||||
// estimate for RDDs. See PR 1238 for more discussions.
|
||||
sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
|
||||
)
|
||||
}
|
|
@ -126,39 +126,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* Allows already planned SparkQueries to be linked into logical query plans.
|
||||
*
|
||||
* Note that in general it is not valid to use this class to link multiple copies of the same
|
||||
* physical operator into the same query plan as this violates the uniqueness of expression ids.
|
||||
* Special handling exists for ExistingRdd as these are already leaf operators and thus we can just
|
||||
* replace the output attributes with new copies of themselves without breaking any attribute
|
||||
* linking.
|
||||
*/
|
||||
@DeveloperApi
|
||||
case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
|
||||
extends LogicalPlan with MultiInstanceRelation {
|
||||
|
||||
def output = alreadyPlanned.output
|
||||
override def children = Nil
|
||||
|
||||
override final def newInstance(): this.type = {
|
||||
SparkLogicalPlan(
|
||||
alreadyPlanned match {
|
||||
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
|
||||
case _ => sys.error("Multiple instance of the same relation detected.")
|
||||
})(sqlContext).asInstanceOf[this.type]
|
||||
}
|
||||
|
||||
@transient override lazy val statistics = Statistics(
|
||||
// TODO: Instead of returning a default value here, find a way to return a meaningful size
|
||||
// estimate for RDDs. See PR 1238 for more discussions.
|
||||
sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
|
||||
self: Product =>
|
||||
}
|
||||
|
|
|
@ -272,10 +272,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
|
||||
case logical.Sample(fraction, withReplacement, seed, child) =>
|
||||
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
|
||||
case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil
|
||||
case logical.LocalRelation(output, data) =>
|
||||
ExistingRdd(
|
||||
PhysicalRDD(
|
||||
output,
|
||||
ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
|
||||
RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
|
||||
case logical.Limit(IntegerLiteral(limit), child) =>
|
||||
execution.Limit(limit, planLater(child)) :: Nil
|
||||
case Unions(unionChildren) =>
|
||||
|
@ -287,12 +288,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case logical.Generate(generator, join, outer, _, child) =>
|
||||
execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
|
||||
case logical.NoRelation =>
|
||||
execution.ExistingRdd(Nil, singleRowRdd) :: Nil
|
||||
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
|
||||
case logical.Repartition(expressions, child) =>
|
||||
execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
|
||||
case e @ EvaluatePython(udf, child) =>
|
||||
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
|
||||
case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
|
||||
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
|
||||
case _ => Nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -210,45 +210,6 @@ case class Sort(
|
|||
override def output = child.output
|
||||
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
*/
|
||||
@DeveloperApi
|
||||
object ExistingRdd {
|
||||
def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
|
||||
data.mapPartitions { iterator =>
|
||||
if (iterator.isEmpty) {
|
||||
Iterator.empty
|
||||
} else {
|
||||
val bufferedIterator = iterator.buffered
|
||||
val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
|
||||
|
||||
bufferedIterator.map { r =>
|
||||
var i = 0
|
||||
while (i < mutableRow.length) {
|
||||
mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
|
||||
i += 1
|
||||
}
|
||||
|
||||
mutableRow
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = {
|
||||
ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
*/
|
||||
@DeveloperApi
|
||||
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
|
||||
override def execute() = rdd
|
||||
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* Computes the set of distinct input rows using a HashSet.
|
||||
|
|
|
@ -20,13 +20,30 @@ package org.apache.spark.sql
|
|||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
|
||||
case class BigData(s: String)
|
||||
|
||||
class CachedTableSuite extends QueryTest {
|
||||
import TestSQLContext._
|
||||
TestData // Load test tables.
|
||||
|
||||
/**
|
||||
* Throws a test failed exception when the number of cached tables differs from the expected
|
||||
* number.
|
||||
*/
|
||||
def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
|
||||
val planWithCaching = query.queryExecution.withCachedData
|
||||
val cachedData = planWithCaching collect {
|
||||
case cached: InMemoryRelation => cached
|
||||
}
|
||||
|
||||
if (cachedData.size != numCachedTables) {
|
||||
fail(
|
||||
s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
|
||||
planWithCaching)
|
||||
}
|
||||
}
|
||||
|
||||
test("too big for memory") {
|
||||
val data = "*" * 10000
|
||||
sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData")
|
||||
|
@ -35,19 +52,21 @@ class CachedTableSuite extends QueryTest {
|
|||
uncacheTable("bigData")
|
||||
}
|
||||
|
||||
test("calling .cache() should use inmemory columnar caching") {
|
||||
table("testData").cache()
|
||||
|
||||
assertCached(table("testData"))
|
||||
}
|
||||
|
||||
test("SPARK-1669: cacheTable should be idempotent") {
|
||||
assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
|
||||
|
||||
cacheTable("testData")
|
||||
table("testData").queryExecution.analyzed match {
|
||||
case _: InMemoryRelation =>
|
||||
case _ =>
|
||||
fail("testData should be cached")
|
||||
}
|
||||
assertCached(table("testData"))
|
||||
|
||||
cacheTable("testData")
|
||||
table("testData").queryExecution.analyzed match {
|
||||
case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) =>
|
||||
case InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) =>
|
||||
fail("cacheTable is not idempotent")
|
||||
|
||||
case _ =>
|
||||
|
@ -55,81 +74,69 @@ class CachedTableSuite extends QueryTest {
|
|||
}
|
||||
|
||||
test("read from cached table and uncache") {
|
||||
TestSQLContext.cacheTable("testData")
|
||||
cacheTable("testData")
|
||||
|
||||
checkAnswer(
|
||||
TestSQLContext.table("testData"),
|
||||
table("testData"),
|
||||
testData.collect().toSeq
|
||||
)
|
||||
|
||||
TestSQLContext.table("testData").queryExecution.analyzed match {
|
||||
case _ : InMemoryRelation => // Found evidence of caching
|
||||
case noCache => fail(s"No cache node found in plan $noCache")
|
||||
}
|
||||
assertCached(table("testData"))
|
||||
|
||||
TestSQLContext.uncacheTable("testData")
|
||||
uncacheTable("testData")
|
||||
|
||||
checkAnswer(
|
||||
TestSQLContext.table("testData"),
|
||||
table("testData"),
|
||||
testData.collect().toSeq
|
||||
)
|
||||
|
||||
TestSQLContext.table("testData").queryExecution.analyzed match {
|
||||
case cachePlan: InMemoryRelation =>
|
||||
fail(s"Table still cached after uncache: $cachePlan")
|
||||
case noCache => // Table uncached successfully
|
||||
}
|
||||
assertCached(table("testData"), 0)
|
||||
}
|
||||
|
||||
test("correct error on uncache of non-cached table") {
|
||||
intercept[IllegalArgumentException] {
|
||||
TestSQLContext.uncacheTable("testData")
|
||||
uncacheTable("testData")
|
||||
}
|
||||
}
|
||||
|
||||
test("SELECT Star Cached Table") {
|
||||
TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar")
|
||||
TestSQLContext.cacheTable("selectStar")
|
||||
TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect()
|
||||
TestSQLContext.uncacheTable("selectStar")
|
||||
sql("SELECT * FROM testData").registerTempTable("selectStar")
|
||||
cacheTable("selectStar")
|
||||
sql("SELECT * FROM selectStar WHERE key = 1").collect()
|
||||
uncacheTable("selectStar")
|
||||
}
|
||||
|
||||
test("Self-join cached") {
|
||||
val unCachedAnswer =
|
||||
TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
|
||||
TestSQLContext.cacheTable("testData")
|
||||
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
|
||||
cacheTable("testData")
|
||||
checkAnswer(
|
||||
TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
|
||||
sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
|
||||
unCachedAnswer.toSeq)
|
||||
TestSQLContext.uncacheTable("testData")
|
||||
uncacheTable("testData")
|
||||
}
|
||||
|
||||
test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
|
||||
TestSQLContext.sql("CACHE TABLE testData")
|
||||
TestSQLContext.table("testData").queryExecution.executedPlan match {
|
||||
case _: InMemoryColumnarTableScan => // Found evidence of caching
|
||||
case _ => fail(s"Table 'testData' should be cached")
|
||||
}
|
||||
assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached")
|
||||
sql("CACHE TABLE testData")
|
||||
assertCached(table("testData"))
|
||||
|
||||
TestSQLContext.sql("UNCACHE TABLE testData")
|
||||
TestSQLContext.table("testData").queryExecution.executedPlan match {
|
||||
case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached")
|
||||
case _ => // Found evidence of uncaching
|
||||
}
|
||||
assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached")
|
||||
assert(isCached("testData"), "Table 'testData' should be cached")
|
||||
|
||||
sql("UNCACHE TABLE testData")
|
||||
assertCached(table("testData"), 0)
|
||||
assert(!isCached("testData"), "Table 'testData' should not be cached")
|
||||
}
|
||||
|
||||
test("CACHE TABLE tableName AS SELECT Star Table") {
|
||||
TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
|
||||
TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect()
|
||||
assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
|
||||
TestSQLContext.uncacheTable("testCacheTable")
|
||||
sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
|
||||
sql("SELECT * FROM testCacheTable WHERE key = 1").collect()
|
||||
assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
|
||||
uncacheTable("testCacheTable")
|
||||
}
|
||||
|
||||
test("'CACHE TABLE tableName AS SELECT ..'") {
|
||||
TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
|
||||
assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
|
||||
TestSQLContext.uncacheTable("testCacheTable")
|
||||
sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
|
||||
assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
|
||||
uncacheTable("testCacheTable")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar
|
|||
import org.apache.spark.sql.catalyst.expressions.Row
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.{QueryTest, TestData}
|
||||
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
|
||||
|
||||
class InMemoryColumnarQuerySuite extends QueryTest {
|
||||
import org.apache.spark.sql.TestData._
|
||||
|
@ -27,7 +28,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
|
|||
|
||||
test("simple columnar query") {
|
||||
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
|
||||
val scan = InMemoryRelation(useCompression = true, 5, plan)
|
||||
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
|
||||
|
||||
checkAnswer(scan, testData.collect().toSeq)
|
||||
}
|
||||
|
@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
|
|||
|
||||
test("projection") {
|
||||
val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
|
||||
val scan = InMemoryRelation(useCompression = true, 5, plan)
|
||||
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
|
||||
|
||||
checkAnswer(scan, testData.collect().map {
|
||||
case Row(key: Int, value: String) => value -> key
|
||||
|
@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
|
|||
|
||||
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
|
||||
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
|
||||
val scan = InMemoryRelation(useCompression = true, 5, plan)
|
||||
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
|
||||
|
||||
checkAnswer(scan, testData.collect().toSeq)
|
||||
checkAnswer(scan, testData.collect().toSeq)
|
||||
|
|
|
@ -133,11 +133,6 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
|
|||
|
||||
case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
|
||||
castChildOutput(p, table, child)
|
||||
|
||||
case p @ logical.InsertIntoTable(
|
||||
InMemoryRelation(_, _, _,
|
||||
HiveTableScan(_, table, _)), _, child, _) =>
|
||||
castChildOutput(p, table, child)
|
||||
}
|
||||
|
||||
def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = {
|
||||
|
@ -306,7 +301,7 @@ private[hive] case class MetastoreRelation
|
|||
HiveMetastoreTypes.toDataType(f.getType),
|
||||
// Since data can be dumped in randomly with no validation, everything is nullable.
|
||||
nullable = true
|
||||
)(qualifiers = tableName +: alias.toSeq)
|
||||
)(qualifiers = Seq(alias.getOrElse(tableName)))
|
||||
}
|
||||
|
||||
// Must be a stable value since new attributes are born here.
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.planning._
|
|||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.types.StringType
|
||||
import org.apache.spark.sql.columnar.InMemoryRelation
|
||||
import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan}
|
||||
import org.apache.spark.sql.hive
|
||||
import org.apache.spark.sql.hive.execution._
|
||||
|
@ -161,10 +160,7 @@ private[hive] trait HiveStrategies {
|
|||
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
|
||||
case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
|
||||
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
|
||||
case logical.InsertIntoTable(
|
||||
InMemoryRelation(_, _, _,
|
||||
HiveTableScan(_, table, _)), partition, child, overwrite) =>
|
||||
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
|
||||
|
||||
case logical.CreateTableAsSelect(database, tableName, child) =>
|
||||
val query = planLater(child)
|
||||
CreateTableAsSelect(
|
||||
|
|
|
@ -353,7 +353,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
|
|||
var cacheTables: Boolean = false
|
||||
def loadTestTable(name: String) {
|
||||
if (!(loadedTables contains name)) {
|
||||
// Marks the table as loaded first to prevent infite mutually recursive table loading.
|
||||
// Marks the table as loaded first to prevent infinite mutually recursive table loading.
|
||||
loadedTables += name
|
||||
logInfo(s"Loading test table $name")
|
||||
val createCmds =
|
||||
|
@ -383,6 +383,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
|
|||
log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN)
|
||||
}
|
||||
|
||||
clearCache()
|
||||
loadedTables.clear()
|
||||
catalog.client.getAllTables("default").foreach { t =>
|
||||
logDebug(s"Deleting table $t")
|
||||
|
@ -428,7 +429,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
|
|||
loadTestTable("srcpart")
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"FATAL ERROR: Failed to reset TestDB state. $e")
|
||||
logError("FATAL ERROR: Failed to reset TestDB state.", e)
|
||||
// At this point there is really no reason to continue, but the test framework traps exits.
|
||||
// So instead we just pause forever so that at least the developer can see where things
|
||||
// started to go wrong.
|
||||
|
|
|
@ -267,6 +267,9 @@ case class InsertIntoHiveTable(
|
|||
holdDDLTime)
|
||||
}
|
||||
|
||||
// Invalidate the cache.
|
||||
sqlContext.invalidateCache(table)
|
||||
|
||||
// It would be nice to just return the childRdd unchanged so insert operations could be chained,
|
||||
// however for now we return an empty list to simplify compatibility checks with hive, which
|
||||
// does not return anything for insert operations.
|
||||
|
|
|
@ -17,22 +17,60 @@
|
|||
|
||||
package org.apache.spark.sql.hive
|
||||
|
||||
import org.apache.spark.sql.execution.SparkLogicalPlan
|
||||
import org.apache.spark.sql.{QueryTest, SchemaRDD}
|
||||
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
|
||||
import org.apache.spark.sql.hive.execution.HiveComparisonTest
|
||||
import org.apache.spark.sql.hive.test.TestHive
|
||||
|
||||
class CachedTableSuite extends HiveComparisonTest {
|
||||
class CachedTableSuite extends QueryTest {
|
||||
import TestHive._
|
||||
|
||||
TestHive.loadTestTable("src")
|
||||
/**
|
||||
* Throws a test failed exception when the number of cached tables differs from the expected
|
||||
* number.
|
||||
*/
|
||||
def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
|
||||
val planWithCaching = query.queryExecution.withCachedData
|
||||
val cachedData = planWithCaching collect {
|
||||
case cached: InMemoryRelation => cached
|
||||
}
|
||||
|
||||
test("cache table") {
|
||||
TestHive.cacheTable("src")
|
||||
if (cachedData.size != numCachedTables) {
|
||||
fail(
|
||||
s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
|
||||
planWithCaching)
|
||||
}
|
||||
}
|
||||
|
||||
createQueryTest("read from cached table",
|
||||
"SELECT * FROM src LIMIT 1", reset = false)
|
||||
test("cache table") {
|
||||
val preCacheResults = sql("SELECT * FROM src").collect().toSeq
|
||||
|
||||
cacheTable("src")
|
||||
assertCached(sql("SELECT * FROM src"))
|
||||
|
||||
checkAnswer(
|
||||
sql("SELECT * FROM src"),
|
||||
preCacheResults)
|
||||
|
||||
uncacheTable("src")
|
||||
assertCached(sql("SELECT * FROM src"), 0)
|
||||
}
|
||||
|
||||
test("cache invalidation") {
|
||||
sql("CREATE TABLE cachedTable(key INT, value STRING)")
|
||||
|
||||
sql("INSERT INTO TABLE cachedTable SELECT * FROM src")
|
||||
checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq)
|
||||
|
||||
cacheTable("cachedTable")
|
||||
checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq)
|
||||
|
||||
sql("INSERT INTO TABLE cachedTable SELECT * FROM src")
|
||||
checkAnswer(
|
||||
sql("SELECT * FROM cachedTable"),
|
||||
table("src").collect().toSeq ++ table("src").collect().toSeq)
|
||||
|
||||
sql("DROP TABLE cachedTable")
|
||||
}
|
||||
|
||||
test("Drop cached table") {
|
||||
sql("CREATE TABLE test(a INT)")
|
||||
|
@ -48,25 +86,6 @@ class CachedTableSuite extends HiveComparisonTest {
|
|||
sql("DROP TABLE IF EXISTS nonexistantTable")
|
||||
}
|
||||
|
||||
test("check that table is cached and uncache") {
|
||||
TestHive.table("src").queryExecution.analyzed match {
|
||||
case _ : InMemoryRelation => // Found evidence of caching
|
||||
case noCache => fail(s"No cache node found in plan $noCache")
|
||||
}
|
||||
TestHive.uncacheTable("src")
|
||||
}
|
||||
|
||||
createQueryTest("read from uncached table",
|
||||
"SELECT * FROM src LIMIT 1", reset = false)
|
||||
|
||||
test("make sure table is uncached") {
|
||||
TestHive.table("src").queryExecution.analyzed match {
|
||||
case cachePlan: InMemoryRelation =>
|
||||
fail(s"Table still cached after uncache: $cachePlan")
|
||||
case noCache => // Table uncached successfully
|
||||
}
|
||||
}
|
||||
|
||||
test("correct error on uncache of non-cached table") {
|
||||
intercept[IllegalArgumentException] {
|
||||
TestHive.uncacheTable("src")
|
||||
|
@ -75,23 +94,24 @@ class CachedTableSuite extends HiveComparisonTest {
|
|||
|
||||
test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") {
|
||||
TestHive.sql("CACHE TABLE src")
|
||||
TestHive.table("src").queryExecution.executedPlan match {
|
||||
case _: InMemoryColumnarTableScan => // Found evidence of caching
|
||||
case _ => fail(s"Table 'src' should be cached")
|
||||
}
|
||||
assertCached(table("src"))
|
||||
assert(TestHive.isCached("src"), "Table 'src' should be cached")
|
||||
|
||||
TestHive.sql("UNCACHE TABLE src")
|
||||
TestHive.table("src").queryExecution.executedPlan match {
|
||||
case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached")
|
||||
case _ => // Found evidence of uncaching
|
||||
}
|
||||
assertCached(table("src"), 0)
|
||||
assert(!TestHive.isCached("src"), "Table 'src' should not be cached")
|
||||
}
|
||||
|
||||
test("'CACHE TABLE tableName AS SELECT ..'") {
|
||||
TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src")
|
||||
assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
|
||||
TestHive.uncacheTable("testCacheTable")
|
||||
}
|
||||
|
||||
test("CACHE TABLE AS SELECT") {
|
||||
assertCached(sql("SELECT * FROM src"), 0)
|
||||
sql("CACHE TABLE test AS SELECT key FROM src")
|
||||
|
||||
checkAnswer(
|
||||
sql("SELECT * FROM test"),
|
||||
sql("SELECT key FROM src").collect().toSeq)
|
||||
|
||||
assertCached(sql("SELECT * FROM test"))
|
||||
|
||||
assertCached(sql("SELECT * FROM test JOIN test"), 2)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue