From 1a68fc38f0aafb9015c499b3f9f7fbe63739e909 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 23 May 2019 11:53:21 -0700
Subject: [PATCH] [SPARK-27816][SQL] make TreeNode tag type safe

## What changes were proposed in this pull request?

Add type parameter to `TreeNodeTag`.

## How was this patch tested?

existing tests

Closes #24687 from cloud-fan/tag.

Authored-by: Wenchen Fan
Signed-off-by: gatorsmile
---
 .../plans/logical/basicLogicalOperators.scala |  2 +-
 .../spark/sql/catalyst/trees/TreeNode.scala   | 21 ++++++++++++++++-----
 .../sql/catalyst/trees/TreeNodeSuite.scala    | 18 ++++++++++--------
 .../spark/sql/execution/SparkPlan.scala       |  5 +++--
 .../spark/sql/execution/SparkStrategies.scala |  2 +-
 .../LogicalPlanTagInSparkPlanSuite.scala      | 11 +++++------
 6 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a2a7eb1ea5..4350f91c9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1083,7 +1083,7 @@ case class OneRowRelation() extends LeafNode {
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
   override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
     val newCopy = OneRowRelation()
-    newCopy.tags ++= this.tags
+    newCopy.copyTagsFrom(this)
     newCopy
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a5705d0f32..cd5dfb7628 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -74,9 +74,8 @@ object CurrentOrigin {
   }
 }
 
-// The name of the tree node tag. This is preferred over using string directly, as we can easily
-// find all the defined tags.
-case class TreeNodeTagName(name: String)
+// A tag of a `TreeNode`, which defines name and type
+case class TreeNodeTag[T](name: String)
 
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
@@ -89,7 +88,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * A mutable map for holding auxiliary information of this tree node. It will be carried over
    * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
    */
-  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+  private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
+
+  protected def copyTagsFrom(other: BaseType): Unit = {
+    tags ++= other.tags
+  }
+
+  def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
+    tags(tag) = value
+  }
+
+  def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
+    tags.get(tag).map(_.asInstanceOf[T])
+  }
 
   /**
    * Returns a Seq of the children of this node.
@@ -418,7 +429,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     try {
       CurrentOrigin.withOrigin(origin) {
         val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
-        res.tags ++= this.tags
+        res.copyTagsFrom(this)
         res
       }
     } catch {
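For readers skimming the diff: the `TreeNodeTag[T]` / `setTagValue` / `getTagValue` trio above is the heart of the change. A minimal standalone sketch of the pattern (toy `Tag`, `Node`, and `TagDemo` names, not the Spark classes) shows what the type parameter buys: the key records its value type, so reads come back as `Option[T]` and the one unchecked cast stays inside the accessor.

```scala
import scala.collection.mutable

// Toy analogue of TreeNodeTag[T]: the phantom type parameter records the
// value type associated with this key.
case class Tag[T](name: String)

class Node {
  // Keys of any tag type, values stored as Any, like TreeNode.tags above.
  private val tags: mutable.Map[Tag[_], Any] = mutable.Map.empty

  def setTagValue[T](tag: Tag[T], value: T): Unit = tags(tag) = value

  // The single asInstanceOf lives here; callers never cast.
  def getTagValue[T](tag: Tag[T]): Option[T] = tags.get(tag).map(_.asInstanceOf[T])
}

object TagDemo extends App {
  val count = Tag[Int]("count")
  val node = new Node
  node.setTagValue(count, 42)
  // node.setTagValue(count, "42")  // would not compile: type mismatch
  assert(node.getTagValue(count) == Some(42))
}
```

Note that the patch also makes `tags` `private`, so the raw untyped `Map` is no longer reachable from outside `TreeNode`; the typed accessors become the only way in.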
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 5cfa84d230..744d522b1b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -622,31 +622,33 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("tags will be carried over after copy & transform") {
+    val tag = TreeNodeTag[String]("test")
+
     withClue("makeCopy") {
       val node = Dummy(None)
-      node.tags += TreeNodeTagName("test") -> "a"
+      node.setTagValue(tag, "a")
       val copied = node.makeCopy(Array(Some(Literal(1))))
-      assert(copied.tags(TreeNodeTagName("test")) == "a")
+      assert(copied.getTagValue(tag) == Some("a"))
     }
 
     def checkTransform(
        sameTypeTransform: Expression => Expression,
        differentTypeTransform: Expression => Expression): Unit = {
      val child = Dummy(None)
-     child.tags += TreeNodeTagName("test") -> "child"
+     child.setTagValue(tag, "child")
      val node = Dummy(Some(child))
-     node.tags += TreeNodeTagName("test") -> "parent"
+     node.setTagValue(tag, "parent")
 
      val transformed = sameTypeTransform(node)
      // Both the child and parent keep the tags
-     assert(transformed.tags(TreeNodeTagName("test")) == "parent")
-     assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+     assert(transformed.getTagValue(tag) == Some("parent"))
+     assert(transformed.children.head.getTagValue(tag) == Some("child"))
 
      val transformed2 = differentTypeTransform(node)
      // Both the child and parent keep the tags, even if we transform the node to a new one of
      // different type.
-     assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
-     assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+     assert(transformed2.getTagValue(tag) == Some("parent"))
+     assert(transformed2.children.head.getTagValue(tag) == Some("child"))
    }
 
    withClue("transformDown") {
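One detail worth noticing in the updated assertions: `getTagValue` returns an `Option`, so the test compares against `Some("a")` instead of indexing the map. A missing tag is now an explicit `None`, where the old `tags(TreeNodeTagName(...))` lookup threw `NoSuchElementException`. A self-contained sketch (same toy names as above, not Spark code):

```scala
import scala.collection.mutable

object MissingTagDemo extends App {
  case class Tag[T](name: String)
  val tags: mutable.Map[Tag[_], Any] = mutable.Map.empty

  def getTagValue[T](tag: Tag[T]): Option[T] = tags.get(tag).map(_.asInstanceOf[T])

  val missing = Tag[String]("never-set")
  assert(getTagValue(missing).isEmpty) // new API: a missing tag is just None
  // tags(missing)                     // old-style direct indexing would throw
}
```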
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 307a01a50e..ddcf61b882 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -33,15 +33,16 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
 
 object SparkPlan {
   // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
   // when converting a logical plan to a physical plan.
-  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+  val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan")
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c9db78b3ed..c4031496f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -69,7 +69,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         case ReturnAnswer(rootPlan) => rootPlan
         case _ => plan
       }
-      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
       p
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index ca7ced5ef5..b35348b4ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.sql.TPCDSQuerySuite
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -81,12 +80,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
       // The exchange related nodes are created after the planning, they don't have corresponding
       // logical plan.
       case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       // The subquery exec nodes are just wrappers of the actual nodes, they don't have
       // corresponding logical plan.
       case _: SubqueryExec | _: ReusedSubqueryExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       case _ if isScanPlanTree(plan) =>
         // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
@@ -120,9 +119,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
   }
 
   private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
-    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
-      node.getClass.getSimpleName + " does not have a logical plan link")
-    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+    node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse {
+      fail(node.getClass.getSimpleName + " does not have a logical plan link")
+    }
   }
 
   private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
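The `LOGICAL_PLAN_TAG` change shows the payoff at a real call site: because the tag is declared as `TreeNodeTag[LogicalPlan]`, `getTagValue` already returns `Option[LogicalPlan]`, and the `asInstanceOf[LogicalPlan]` in `getLogicalPlan` disappears. A standalone sketch of that shape (stub `Tag`, `Node`, and `PlanStub` types, not the Spark classes):

```scala
import scala.collection.mutable

object LogicalLinkDemo extends App {
  case class Tag[T](name: String)
  case class PlanStub(name: String) // stands in for LogicalPlan

  class Node {
    private val tags: mutable.Map[Tag[_], Any] = mutable.Map.empty
    def setTagValue[T](tag: Tag[T], value: T): Unit = tags(tag) = value
    def getTagValue[T](tag: Tag[T]): Option[T] = tags.get(tag).map(_.asInstanceOf[T])
  }

  // Mirrors SparkPlan.LOGICAL_PLAN_TAG: the value type is fixed at the declaration.
  val logicalPlanTag = Tag[PlanStub]("logical_plan")

  val physical = new Node
  physical.setTagValue(logicalPlanTag, PlanStub("Aggregate"))

  // Before: tags(LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
  // After:  the Option is already typed; failure handling is a plain getOrElse.
  val linked: PlanStub = physical.getTagValue(logicalPlanTag)
    .getOrElse(sys.error("node does not have a logical plan link"))
  assert(linked.name == "Aggregate")
}
```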