[SPARK-27816][SQL] make TreeNode tag type safe
## What changes were proposed in this pull request?

Add a type parameter to `TreeNodeTag`.

## How was this patch tested?

Existing tests.

Closes #24687 from cloud-fan/tag.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
This commit is contained in:
parent
74e5e41eeb
commit
1a68fc38f0
|
@ -1083,7 +1083,7 @@ case class OneRowRelation() extends LeafNode {
|
|||
/** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
|
||||
override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
|
||||
val newCopy = OneRowRelation()
|
||||
newCopy.tags ++= this.tags
|
||||
newCopy.copyTagsFrom(this)
|
||||
newCopy
|
||||
}
|
||||
}
|
||||
|
|
|
@ -74,9 +74,8 @@ object CurrentOrigin {
|
|||
}
|
||||
}
|
||||
|
||||
// The name of the tree node tag. This is preferred over using string directly, as we can easily
|
||||
// find all the defined tags.
|
||||
case class TreeNodeTagName(name: String)
|
||||
// A tag of a `TreeNode`, which defines name and type
|
||||
case class TreeNodeTag[T](name: String)
|
||||
|
||||
// scalastyle:off
|
||||
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
|
||||
|
@ -89,7 +88,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
|
|||
* A mutable map for holding auxiliary information of this tree node. It will be carried over
|
||||
* when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
|
||||
*/
|
||||
val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
|
||||
private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
|
||||
|
||||
protected def copyTagsFrom(other: BaseType): Unit = {
|
||||
tags ++= other.tags
|
||||
}
|
||||
|
||||
def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
|
||||
tags(tag) = value
|
||||
}
|
||||
|
||||
def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
|
||||
tags.get(tag).map(_.asInstanceOf[T])
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a Seq of the children of this node.
|
||||
|
@ -418,7 +429,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
|
|||
try {
|
||||
CurrentOrigin.withOrigin(origin) {
|
||||
val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
|
||||
res.tags ++= this.tags
|
||||
res.copyTagsFrom(this)
|
||||
res
|
||||
}
|
||||
} catch {
|
||||
|
|
|
@ -622,31 +622,33 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
|
|||
}
|
||||
|
||||
test("tags will be carried over after copy & transform") {
|
||||
val tag = TreeNodeTag[String]("test")
|
||||
|
||||
withClue("makeCopy") {
|
||||
val node = Dummy(None)
|
||||
node.tags += TreeNodeTagName("test") -> "a"
|
||||
node.setTagValue(tag, "a")
|
||||
val copied = node.makeCopy(Array(Some(Literal(1))))
|
||||
assert(copied.tags(TreeNodeTagName("test")) == "a")
|
||||
assert(copied.getTagValue(tag) == Some("a"))
|
||||
}
|
||||
|
||||
def checkTransform(
|
||||
sameTypeTransform: Expression => Expression,
|
||||
differentTypeTransform: Expression => Expression): Unit = {
|
||||
val child = Dummy(None)
|
||||
child.tags += TreeNodeTagName("test") -> "child"
|
||||
child.setTagValue(tag, "child")
|
||||
val node = Dummy(Some(child))
|
||||
node.tags += TreeNodeTagName("test") -> "parent"
|
||||
node.setTagValue(tag, "parent")
|
||||
|
||||
val transformed = sameTypeTransform(node)
|
||||
// Both the child and parent keep the tags
|
||||
assert(transformed.tags(TreeNodeTagName("test")) == "parent")
|
||||
assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
|
||||
assert(transformed.getTagValue(tag) == Some("parent"))
|
||||
assert(transformed.children.head.getTagValue(tag) == Some("child"))
|
||||
|
||||
val transformed2 = differentTypeTransform(node)
|
||||
// Both the child and parent keep the tags, even if we transform the node to a new one of
|
||||
// different type.
|
||||
assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
|
||||
assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
|
||||
assert(transformed2.getTagValue(tag) == Some("parent"))
|
||||
assert(transformed2.children.head.getTagValue(tag) == Some("child"))
|
||||
}
|
||||
|
||||
withClue("transformDown") {
|
||||
|
|
|
@ -33,15 +33,16 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
|
||||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
|
||||
import org.apache.spark.sql.execution.metric.SQLMetric
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
object SparkPlan {
|
||||
// a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
|
||||
// when converting a logical plan to a physical plan.
|
||||
val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
|
||||
val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan")
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -69,7 +69,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case ReturnAnswer(rootPlan) => rootPlan
|
||||
case _ => plan
|
||||
}
|
||||
p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
|
||||
p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
|
||||
p
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import scala.reflect.ClassTag
|
|||
|
||||
import org.apache.spark.sql.TPCDSQuerySuite
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
|
||||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
|
||||
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
|
||||
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
|
||||
|
@ -81,12 +80,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
|
|||
// The exchange related nodes are created after the planning, they don't have corresponding
|
||||
// logical plan.
|
||||
case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
|
||||
assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
|
||||
assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
|
||||
|
||||
// The subquery exec nodes are just wrappers of the actual nodes, they don't have
|
||||
// corresponding logical plan.
|
||||
case _: SubqueryExec | _: ReusedSubqueryExec =>
|
||||
assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
|
||||
assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
|
||||
|
||||
case _ if isScanPlanTree(plan) =>
|
||||
// The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
|
||||
|
@ -120,9 +119,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
|
|||
}
|
||||
|
||||
private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
|
||||
assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
|
||||
node.getClass.getSimpleName + " does not have a logical plan link")
|
||||
node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
|
||||
node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse {
|
||||
fail(node.getClass.getSimpleName + " does not have a logical plan link")
|
||||
}
|
||||
}
|
||||
|
||||
private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
|
||||
|
|
Loading…
Reference in a new issue