[SPARK-5354][SQL] Cached tables should preserve partitioning and ordering.

For cached tables, we can just maintain the partitioning and ordering from the
source relation.

Author: Nong Li <nongli@gmail.com>

Closes #9404 from nongli/spark-5354.
Authored by Nong Li on 2015-11-02 19:18:45 -08:00; committed by Yin Huai
parent 21ad846238
commit 2cef1bb0b5
3 changed files with 97 additions and 9 deletions


@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.storage.StorageLevel
@@ -209,6 +210,12 @@ private[sql] case class InMemoryColumnarTableScan(
override def output: Seq[Attribute] = attributes
// The cached version does not change the outputPartitioning of the original SparkPlan.
override def outputPartitioning: Partitioning = relation.child.outputPartitioning
// The cached version does not change the outputOrdering of the original SparkPlan.
override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering
override def outputsUnsafeRows: Boolean = true
private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
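
One hedged way to observe this pass-through from a REPL or a test (assumes `df` is a DataFrame whose underlying relation was cached from a partitioned and sorted plan):

    val scan = df.queryExecution.executedPlan.collect {
      case s: InMemoryColumnarTableScan => s
    }.head
    // After this patch these mirror the plan that was cached,
    // instead of falling back to unknown partitioning and no ordering.
    scan.outputPartitioning
    scan.outputOrdering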


@@ -194,12 +194,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
*/
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
private def numPartitions: Int = sqlContext.conf.numShufflePartitions
private def defaultPartitions: Int = sqlContext.conf.numShufflePartitions
/**
* Given a required distribution, returns a partitioning that satisfies that distribution.
*/
private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
private def createPartitioning(requiredDistribution: Distribution,
numPartitions: Int): Partitioning = {
requiredDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
@@ -220,7 +221,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
if (child.outputPartitioning.satisfies(distribution)) {
child
} else {
Exchange(canonicalPartitioning(distribution), child)
Exchange(createPartitioning(distribution, defaultPartitions), child)
}
}
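
For reference, the two cases shown in this hunk map as follows (illustrative only; cases not shown in the hunk are omitted):

    // createPartitioning(AllTuples, n)                       == SinglePartition
    // createPartitioning(ClusteredDistribution(Seq(k)), 200) == HashPartitioning(Seq(k), 200)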
@@ -229,12 +230,33 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
if (children.length > 1
&& requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
&& !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
val targetPartitioning = canonicalPartitioning(distribution)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
Exchange(targetPartitioning, child)
// First check if the existing partitions of the children all match. This means they are
// partitioned by the same partitioning into the same number of partitions. In that case,
// don't try to make them match `defaultPartitions`, just use the existing partitioning.
// TODO: this should be a cost-based decision. For example, a big relation should probably
// maintain its existing number of partitions and smaller partitions should be shuffled.
// defaultPartitions is arbitrary.
val numPartitions = children.head.outputPartitioning.numPartitions
val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
case (child, distribution) => {
child.outputPartitioning.guarantees(
createPartitioning(distribution, numPartitions))
}
}
children = if (useExistingPartitioning) {
children
} else {
children.zip(requiredChildDistributions).map {
case (child, distribution) => {
val targetPartitioning = createPartitioning(distribution, defaultPartitions)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
Exchange(targetPartitioning, child)
}
}
}
}
}
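
Concretely, this is the check that lets the cached-join case in the new test avoid shuffles; a sketch mirroring that test (the temp-table names here are made up):

    // Both sides pre-partitioned on the join keys into the same number of partitions.
    testData.distributeBy(Column("key") :: Nil, 8).registerTempTable("t1_hashed")
    testData2.distributeBy(Column("a") :: Nil, 8).registerTempTable("t2_hashed")
    sqlContext.cacheTable("t1_hashed")
    sqlContext.cacheTable("t2_hashed")

    // Each child's HashPartitioning guarantees the partitioning built from its required
    // ClusteredDistribution with numPartitions = 8, so no Exchange is added for the join.
    sql("SELECT * FROM t1_hashed JOIN t2_hashed ON t1_hashed.key = t2_hashed.a").explain()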


@@ -18,6 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.PhysicalRDD
import scala.concurrent.duration._
@@ -353,4 +354,62 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3)
assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0)
}
/**
* Verifies that the plan for `df` contains `expected` number of Exchange operators.
*/
private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected)
}
test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
val table3x = testData.unionAll(testData).unionAll(testData)
table3x.registerTempTable("testData3x")
sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable")
sqlContext.cacheTable("orderedTable")
assertCached(sqlContext.table("orderedTable"))
// Should not have an exchange as the query is already sorted on the group by key.
verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0)
checkAnswer(
sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect())
sqlContext.uncacheTable("orderedTable")
// Set up two tables distributed in the same way. Try this with the data distributed into
// different numbers of partitions.
for (numPartitions <- 1 until 10 by 4) {
testData.distributeBy(Column("key") :: Nil, numPartitions).registerTempTable("t1")
testData2.distributeBy(Column("a") :: Nil, numPartitions).registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")
// Joining them should result in no exchanges.
verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
// Grouping on the partition key should result in no exchanges
verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
sql("SELECT count(*) FROM testData GROUP BY key"))
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
sqlContext.dropTempTable("t1")
sqlContext.dropTempTable("t2")
}
// Distribute the tables into non-matching number of partitions. Need to shuffle.
testData.distributeBy(Column("key") :: Nil, 6).registerTempTable("t1")
testData2.distributeBy(Column("a") :: Nil, 3).registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")
verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
sqlContext.dropTempTable("t1")
sqlContext.dropTempTable("t2")
}
}