[SPARK-20897][SQL] cached self-join should not fail
## What changes were proposed in this pull request? The failing test case has a `SortMergeJoinExec` for a self-join, which means there is a `ReusedExchange` node in the query plan. It works fine without caching, but throws an exception in `SortMergeJoinExec.outputPartitioning` if we cache it. The root cause is that `ReusedExchange` doesn't propagate the output partitioning from its child, so in `SortMergeJoinExec.outputPartitioning` we create a `PartitioningCollection` with a hash partitioning and an unknown partitioning, and fail. This bug is mostly harmless, because inserting the `ReusedExchange` is the last step in preparing the physical plan, so we won't call `SortMergeJoinExec.outputPartitioning` again after this. However, if the dataframe is cached, its physical plan becomes `InMemoryTableScanExec`, which contains another physical plan representing the cached query; that plan has gone through the entire planning phase and may contain a `ReusedExchange`. The planner then calls `InMemoryTableScanExec.outputPartitioning`, which in turn calls `SortMergeJoinExec.outputPartitioning` and triggers this bug. ## How was this patch tested? A new regression test. Author: Wenchen Fan <wenchen@databricks.com> Closes #18121 from cloud-fan/bug.
This commit is contained in:
parent
8faffc4167
commit
08ede46b89
|
@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer
|
||||||
import org.apache.spark.broadcast
|
import org.apache.spark.broadcast
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
|
||||||
|
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
|
||||||
import org.apache.spark.sql.catalyst.rules.Rule
|
import org.apache.spark.sql.catalyst.rules.Rule
|
||||||
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
|
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
|
@ -58,6 +59,24 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
|
||||||
override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
|
override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
|
||||||
child.executeBroadcast()
|
child.executeBroadcast()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `ReusedExchangeExec` can have a distinct set of output attribute ids from its child, so we need
|
||||||
|
// to update the attribute ids in `outputPartitioning` and `outputOrdering`.
|
||||||
|
private lazy val updateAttr: Expression => Expression = {
|
||||||
|
val originalAttrToNewAttr = AttributeMap(child.output.zip(output))
|
||||||
|
e => e.transform {
|
||||||
|
case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override def outputPartitioning: Partitioning = child.outputPartitioning match {
|
||||||
|
case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr))
|
||||||
|
case other => other
|
||||||
|
}
|
||||||
|
|
||||||
|
override def outputOrdering: Seq[SortOrder] = {
|
||||||
|
child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1855,4 +1855,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
|
||||||
.foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
|
.foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
|
||||||
df.filter(filter).count
|
df.filter(filter).count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-20897: cached self-join should not fail") {
|
||||||
|
// force the planner to choose a sort merge join
|
||||||
|
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
|
||||||
|
val df = Seq(1 -> "a").toDF("i", "j")
|
||||||
|
val df1 = df.as("t1")
|
||||||
|
val df2 = df.as("t2")
|
||||||
|
assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue