[SPARK-27816][SQL] make TreeNode tag type safe

## What changes were proposed in this pull request?

Add type parameter to `TreeNodeTag`.
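
In practice, the type parameter moves the cast out of caller code: `getTagValue` on a `TreeNodeTag[T]` returns `Option[T]`. A hedged before/after sketch using names from the diffs below (`plan` is a hypothetical `SparkPlan` value):

```scala
// Before: `tags` was a public Map[TreeNodeTagName, Any], so every read needed a cast.
val before = plan.tags(TreeNodeTagName("logical_plan")).asInstanceOf[LogicalPlan]

// After: the tag carries its value type, so the result is already typed.
val after: Option[LogicalPlan] = plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
```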

## How was this patch tested?

Existing tests.

Closes #24687 from cloud-fan/tag.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
Authored by Wenchen Fan on 2019-05-23 11:53:21 -07:00; committed by gatorsmile
parent 74e5e41eeb / commit 1a68fc38f0
6 changed files with 36 additions and 23 deletions

--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1083,7 +1083,7 @@ case class OneRowRelation() extends LeafNode {
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
   override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
     val newCopy = OneRowRelation()
-    newCopy.tags ++= this.tags
+    newCopy.copyTagsFrom(this)
     newCopy
   }
 }

--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -74,9 +74,8 @@ object CurrentOrigin {
   }
 }
 
-// The name of the tree node tag. This is preferred over using string directly, as we can easily
-// find all the defined tags.
-case class TreeNodeTagName(name: String)
+// A tag of a `TreeNode`, which defines name and type
+case class TreeNodeTag[T](name: String)
 
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
@@ -89,7 +88,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   /**
    * A mutable map for holding auxiliary information of this tree node. It will be carried over
    * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
    */
-  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+  private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
+
+  protected def copyTagsFrom(other: BaseType): Unit = {
+    tags ++= other.tags
+  }
+
+  def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
+    tags(tag) = value
+  }
+
+  def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
+    tags.get(tag).map(_.asInstanceOf[T])
+  }
 
   /**
    * Returns a Seq of the children of this node.
@@ -418,7 +429,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     try {
       CurrentOrigin.withOrigin(origin) {
         val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
-        res.tags ++= this.tags
+        res.copyTagsFrom(this)
         res
       }
     } catch {
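
The `asInstanceOf[T]` inside `getTagValue` never fails at runtime: the map is now `private`, `setTagValue` is the only write path, and it forces the value's type to match the key's type parameter; since `TreeNodeTag[T]` is invariant in `T`, the compiler cannot widen the key to defeat the check. Below is a minimal, self-contained sketch of this pattern; `Tag`, `Node`, and `maxRows` are stand-in names for illustration, not the Spark classes:

```scala
import scala.collection.mutable

// Stand-in for TreeNodeTag: the phantom type parameter T records the value type.
final case class Tag[T](name: String)

// Stand-in for TreeNode's tag storage.
class Node {
  // Keys are existentially typed; values are Any, but all writes go through setTagValue.
  private val tags: mutable.Map[Tag[_], Any] = mutable.Map.empty

  // The compiler requires `value` to match the tag's type parameter.
  def setTagValue[T](tag: Tag[T], value: T): Unit = tags(tag) = value

  // The cast is safe: only setTagValue can insert a value under a Tag[T] key.
  def getTagValue[T](tag: Tag[T]): Option[T] = tags.get(tag).map(_.asInstanceOf[T])
}

object TagDemo extends App {
  val maxRows = Tag[Int]("maxRows")
  val node = new Node
  node.setTagValue(maxRows, 100)         // OK: Int matches Tag[Int]
  // node.setTagValue(maxRows, "100")    // does not compile: type mismatch
  println(node.getTagValue(maxRows))     // Some(100), statically Option[Int]
}
```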

--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -622,31 +622,33 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("tags will be carried over after copy & transform") {
+    val tag = TreeNodeTag[String]("test")
+
     withClue("makeCopy") {
       val node = Dummy(None)
-      node.tags += TreeNodeTagName("test") -> "a"
+      node.setTagValue(tag, "a")
       val copied = node.makeCopy(Array(Some(Literal(1))))
-      assert(copied.tags(TreeNodeTagName("test")) == "a")
+      assert(copied.getTagValue(tag) == Some("a"))
     }
 
     def checkTransform(
         sameTypeTransform: Expression => Expression,
         differentTypeTransform: Expression => Expression): Unit = {
       val child = Dummy(None)
-      child.tags += TreeNodeTagName("test") -> "child"
+      child.setTagValue(tag, "child")
       val node = Dummy(Some(child))
-      node.tags += TreeNodeTagName("test") -> "parent"
+      node.setTagValue(tag, "parent")
 
       val transformed = sameTypeTransform(node)
       // Both the child and parent keep the tags
-      assert(transformed.tags(TreeNodeTagName("test")) == "parent")
-      assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+      assert(transformed.getTagValue(tag) == Some("parent"))
+      assert(transformed.children.head.getTagValue(tag) == Some("child"))
 
       val transformed2 = differentTypeTransform(node)
       // Both the child and parent keep the tags, even if we transform the node to a new one of
       // different type.
-      assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
-      assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+      assert(transformed2.getTagValue(tag) == Some("parent"))
+      assert(transformed2.children.head.getTagValue(tag) == Some("child"))
     }
 
     withClue("transformDown") {

--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -33,15 +33,16 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
 
 object SparkPlan {
   // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
   // when converting a logical plan to a physical plan.
-  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+  val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan")
 }
 
 /**

--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -69,7 +69,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         case ReturnAnswer(rootPlan) => rootPlan
         case _ => plan
       }
-      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
       p
     }
   }
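
With the tag typed as `TreeNodeTag[LogicalPlan]`, the read side needs no cast either; a short sketch, assuming a hypothetical planned `physicalPlan: SparkPlan`:

```scala
// Returns Option[LogicalPlan] directly; compare the removed asInstanceOf in the suite below.
val origin: Option[LogicalPlan] = physicalPlan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
```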

--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.sql.TPCDSQuerySuite
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -81,12 +80,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
       // The exchange related nodes are created after the planning, they don't have corresponding
       // logical plan.
       case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       // The subquery exec nodes are just wrappers of the actual nodes, they don't have
       // corresponding logical plan.
       case _: SubqueryExec | _: ReusedSubqueryExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       case _ if isScanPlanTree(plan) =>
         // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
@@ -120,9 +119,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
   }
 
   private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
-    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
-      node.getClass.getSimpleName + " does not have a logical plan link")
-    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+    node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse {
+      fail(node.getClass.getSimpleName + " does not have a logical plan link")
+    }
   }
 
   private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {