diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index d993ea6c6c..4b52f3e4c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -58,6 +59,24 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { child.executeBroadcast() } + + // `ReusedExchangeExec` can have distinct set of output attribute ids from its child, we need + // to update the attribute ids in `outputPartitioning` and `outputOrdering`. + private lazy val updateAttr: Expression => Expression = { + val originalAttrToNewAttr = AttributeMap(child.output.zip(output)) + e => e.transform { + case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) + } + } + + override def outputPartitioning: Partitioning = child.outputPartitioning match { + case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case other => other + } + + override def outputOrdering: Seq[SortOrder] = { + child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2f52192b54..9f691cb10f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1855,4 +1855,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string")) df.filter(filter).count } + + test("SPARK-20897: cached self-join should not fail") { + // force to plan sort merge join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df = Seq(1 -> "a").toDF("i", "j") + val df1 = df.as("t1") + val df2 = df.as("t2") + assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1) + } + } }