[SPARK-35940][SQL] Refactor EquivalentExpressions to make it more efficient
### What changes were proposed in this pull request?
This PR uses 2 ideas to make `EquivalentExpressions` more efficient:
1. do not keep all the equivalent expressions, we only need a count
2. track the "height" of common subexpressions, to quickly do child-parent sort, and filter out non-child expressions in `addCommonExprs`
This PR also fixes several small bugs (exposed by the refactoring), please see PR comments.
### Why are the changes needed?
code cleanup and small perf improvement
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
Closes #33142 from cloud-fan/codegen.
Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit e6ce220690
)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
parent
8bc54c2d6d
commit
ec84982191
|
@ -29,20 +29,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
|
||||||
* considered equal if for the same input(s), the same result is produced.
|
* considered equal if for the same input(s), the same result is produced.
|
||||||
*/
|
*/
|
||||||
class EquivalentExpressions {
|
class EquivalentExpressions {
|
||||||
/**
|
|
||||||
* Wrapper around an Expression that provides semantic equality.
|
|
||||||
*/
|
|
||||||
case class Expr(e: Expression) {
|
|
||||||
override def equals(o: Any): Boolean = o match {
|
|
||||||
case other: Expr => e.semanticEquals(other.e)
|
|
||||||
case _ => false
|
|
||||||
}
|
|
||||||
|
|
||||||
override def hashCode: Int = e.semanticHash()
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each expression, the set of equivalent expressions.
|
// For each expression, the set of equivalent expressions.
|
||||||
private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
|
private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds each expression to this data structure, grouping them with existing equivalent
|
* Adds each expression to this data structure, grouping them with existing equivalent
|
||||||
|
@ -50,28 +38,19 @@ class EquivalentExpressions {
|
||||||
* Returns true if there was already a matching expression.
|
* Returns true if there was already a matching expression.
|
||||||
*/
|
*/
|
||||||
def addExpr(expr: Expression): Boolean = {
|
def addExpr(expr: Expression): Boolean = {
|
||||||
if (expr.deterministic) {
|
addExprToMap(expr, equivalenceMap)
|
||||||
val e: Expr = Expr(expr)
|
|
||||||
val f = equivalenceMap.get(e)
|
|
||||||
if (f.isDefined) {
|
|
||||||
f.get += expr
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
equivalenceMap.put(e, mutable.ArrayBuffer(expr))
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = {
|
private def addExprToMap(
|
||||||
|
expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Boolean = {
|
||||||
if (expr.deterministic) {
|
if (expr.deterministic) {
|
||||||
val e = Expr(expr)
|
val wrapper = ExpressionEquals(expr)
|
||||||
if (set.contains(e)) {
|
map.get(wrapper) match {
|
||||||
|
case Some(stats) =>
|
||||||
|
stats.useCount += 1
|
||||||
true
|
true
|
||||||
} else {
|
case _ =>
|
||||||
set.add(e)
|
map.put(wrapper, ExpressionStats(expr)())
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -93,25 +72,33 @@ class EquivalentExpressions {
|
||||||
*/
|
*/
|
||||||
private def addCommonExprs(
|
private def addCommonExprs(
|
||||||
exprs: Seq[Expression],
|
exprs: Seq[Expression],
|
||||||
addFunc: Expression => Boolean = addExpr): Unit = {
|
map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = {
|
||||||
val exprSetForAll = mutable.Set[Expr]()
|
assert(exprs.length > 1)
|
||||||
addExprTree(exprs.head, addExprToSet(_, exprSetForAll))
|
var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
|
||||||
|
addExprTree(exprs.head, localEquivalenceMap)
|
||||||
|
|
||||||
val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
|
exprs.tail.foreach { expr =>
|
||||||
val otherExprSet = mutable.Set[Expr]()
|
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
|
||||||
addExprTree(expr, addExprToSet(_, otherExprSet))
|
addExprTree(expr, otherLocalEquivalenceMap)
|
||||||
exprSet.intersect(otherExprSet)
|
localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
|
||||||
}
|
otherLocalEquivalenceMap.contains(key)
|
||||||
|
|
||||||
// Not all expressions in the set should be added. We should filter out the related
|
|
||||||
// children nodes.
|
|
||||||
val commonExprSet = candidateExprs.filter { candidateExpr =>
|
|
||||||
candidateExprs.forall { expr =>
|
|
||||||
expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
|
localEquivalenceMap.foreach { case (commonExpr, state) =>
|
||||||
|
val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height }
|
||||||
|
val notChild = possibleParents.forall { case (k, _) =>
|
||||||
|
k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty
|
||||||
|
}
|
||||||
|
if (notChild) {
|
||||||
|
// If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will
|
||||||
|
// increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree`
|
||||||
|
// will recursively add `commonExpr` and its descendant to the equivalence map, in case
|
||||||
|
// they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`,
|
||||||
|
// `a + b` also appears in the condition and should be treated as common subexpression.
|
||||||
|
addExprTree(commonExpr.e, map)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// There are some special expressions that we should not recurse into all of its children.
|
// There are some special expressions that we should not recurse into all of its children.
|
||||||
|
@ -135,6 +122,7 @@ class EquivalentExpressions {
|
||||||
// For some special expressions we cannot just recurse into all of its children, but we can
|
// For some special expressions we cannot just recurse into all of its children, but we can
|
||||||
// recursively add the common expressions shared between all of its children.
|
// recursively add the common expressions shared between all of its children.
|
||||||
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
|
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
|
||||||
|
case _: CodegenFallback => Nil
|
||||||
case i: If => Seq(Seq(i.trueValue, i.falseValue))
|
case i: If => Seq(Seq(i.trueValue, i.falseValue))
|
||||||
case c: CaseWhen =>
|
case c: CaseWhen =>
|
||||||
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
|
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
|
||||||
|
@ -142,7 +130,13 @@ class EquivalentExpressions {
|
||||||
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
|
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
|
||||||
// a subexpression among values doesn't need to be in conditions because no matter which
|
// a subexpression among values doesn't need to be in conditions because no matter which
|
||||||
// condition is true, it will be evaluated.
|
// condition is true, it will be evaluated.
|
||||||
val conditions = c.branches.tail.map(_._1)
|
val conditions = if (c.branches.length > 1) {
|
||||||
|
c.branches.map(_._1)
|
||||||
|
} else {
|
||||||
|
// If there is only one branch, the first condition is already covered by
|
||||||
|
// `childrenToRecurse` and we should exclude it here.
|
||||||
|
Nil
|
||||||
|
}
|
||||||
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
|
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
|
||||||
// the elseValue.
|
// the elseValue.
|
||||||
val values = if (c.elseValue.nonEmpty) {
|
val values = if (c.elseValue.nonEmpty) {
|
||||||
|
@ -150,8 +144,11 @@ class EquivalentExpressions {
|
||||||
} else {
|
} else {
|
||||||
Nil
|
Nil
|
||||||
}
|
}
|
||||||
|
|
||||||
Seq(conditions, values)
|
Seq(conditions, values)
|
||||||
case c: Coalesce => Seq(c.children.tail)
|
// If there is only one child, the first child is already covered by
|
||||||
|
// `childrenToRecurse` and we should exclude it here.
|
||||||
|
case c: Coalesce if c.children.length > 1 => Seq(c.children)
|
||||||
case _ => Nil
|
case _ => Nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +158,7 @@ class EquivalentExpressions {
|
||||||
*/
|
*/
|
||||||
def addExprTree(
|
def addExprTree(
|
||||||
expr: Expression,
|
expr: Expression,
|
||||||
addFunc: Expression => Boolean = addExpr): Unit = {
|
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = {
|
||||||
val skip = expr.isInstanceOf[LeafExpression] ||
|
val skip = expr.isInstanceOf[LeafExpression] ||
|
||||||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
|
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
|
||||||
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
|
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
|
||||||
|
@ -170,27 +167,30 @@ class EquivalentExpressions {
|
||||||
// can cause error like NPE.
|
// can cause error like NPE.
|
||||||
(expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)
|
(expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)
|
||||||
|
|
||||||
if (!skip && !addFunc(expr)) {
|
if (!skip && !addExprToMap(expr, map)) {
|
||||||
childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
|
childrenToRecurse(expr).foreach(addExprTree(_, map))
|
||||||
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
|
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns all of the expression trees that are equivalent to `e`. Returns
|
* Returns the state of the given expression in the `equivalenceMap`. Returns None if there is no
|
||||||
* an empty collection if there are none.
|
* equivalent expressions.
|
||||||
*/
|
*/
|
||||||
def getEquivalentExprs(e: Expression): Seq[Expression] = {
|
def getExprState(e: Expression): Option[ExpressionStats] = {
|
||||||
equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq
|
equivalenceMap.get(ExpressionEquals(e))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exposed for testing.
|
||||||
|
private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
|
||||||
|
equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns all the equivalent sets of expressions which appear more than given `repeatTimes`
|
* Returns a sequence of expressions that more than one equivalent expressions.
|
||||||
* times.
|
|
||||||
*/
|
*/
|
||||||
def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
|
def getCommonSubexpressions: Seq[Expression] = {
|
||||||
equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
|
getAllExprStates(1).map(_.expr)
|
||||||
.sortBy(_.head)(new ExpressionContainmentOrdering)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -198,37 +198,40 @@ class EquivalentExpressions {
|
||||||
* equivalent expressions with cardinality 1.
|
* equivalent expressions with cardinality 1.
|
||||||
*/
|
*/
|
||||||
def debugString(all: Boolean = false): String = {
|
def debugString(all: Boolean = false): String = {
|
||||||
val sb: mutable.StringBuilder = new StringBuilder()
|
val sb = new java.lang.StringBuilder()
|
||||||
sb.append("Equivalent expressions:\n")
|
sb.append("Equivalent expressions:\n")
|
||||||
equivalenceMap.foreach { case (k, v) =>
|
equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats =>
|
||||||
if (all || v.length > 1) {
|
sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n')
|
||||||
sb.append(" " + v.mkString(", ")).append("\n")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
sb.toString()
|
sb.toString()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Orders `Expression` by parent/child relations. The child expression is smaller
|
* Wrapper around an Expression that provides semantic equality.
|
||||||
* than parent expression. If there is child-parent relationships among the subexpressions,
|
|
||||||
* we want the child expressions come first than parent expressions, so we can replace
|
|
||||||
* child expressions in parent expressions with subexpression evaluation. Note that
|
|
||||||
* this is not for general expression ordering. For example, two irrelevant or semantically-equal
|
|
||||||
* expressions will be considered as equal by this ordering. But for the usage here, the order of
|
|
||||||
* irrelevant expressions does not matter.
|
|
||||||
*/
|
*/
|
||||||
class ExpressionContainmentOrdering extends Ordering[Expression] {
|
case class ExpressionEquals(e: Expression) {
|
||||||
override def compare(x: Expression, y: Expression): Int = {
|
override def equals(o: Any): Boolean = o match {
|
||||||
if (x.find(_.semanticEquals(y)).isDefined) {
|
case other: ExpressionEquals => e.semanticEquals(other.e)
|
||||||
// `y` is child expression of `x`.
|
case _ => false
|
||||||
1
|
|
||||||
} else if (y.find(_.semanticEquals(x)).isDefined) {
|
|
||||||
// `x` is child expression of `y`.
|
|
||||||
-1
|
|
||||||
} else {
|
|
||||||
// Irrelevant or semantically-equal expressions
|
|
||||||
0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def hashCode: Int = e.semanticHash()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A wrapper in place of using Seq[Expression] to record a group of equivalent expressions.
|
||||||
|
*
|
||||||
|
* This saves a lot of memory when there are a lot of expressions in a same equivalence group.
|
||||||
|
* Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
|
||||||
|
* useCount in this wrapper in-place.
|
||||||
|
*/
|
||||||
|
case class ExpressionStats(expr: Expression)(var useCount: Int = 1) {
|
||||||
|
// This is used to do a fast pre-check for child-parent relationship. For example, expr1 can
|
||||||
|
// only be a parent of expr2 if expr1.height is larger than expr2.height.
|
||||||
|
lazy val height = getHeight(expr)
|
||||||
|
|
||||||
|
private def getHeight(tree: Expression): Int = {
|
||||||
|
tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,7 +136,7 @@ abstract class Expression extends TreeNode[Expression] {
|
||||||
* @return [[ExprCode]]
|
* @return [[ExprCode]]
|
||||||
*/
|
*/
|
||||||
def genCode(ctx: CodegenContext): ExprCode = {
|
def genCode(ctx: CodegenContext): ExprCode = {
|
||||||
ctx.subExprEliminationExprs.get(this).map { subExprState =>
|
ctx.subExprEliminationExprs.get(ExpressionEquals(this)).map { subExprState =>
|
||||||
// This expression is repeated which means that the code to evaluate it has already been added
|
// This expression is repeated which means that the code to evaluate it has already been added
|
||||||
// as a function before. In that case, we just re-use it.
|
// as a function before. In that case, we just re-use it.
|
||||||
ExprCode(
|
ExprCode(
|
||||||
|
|
|
@ -73,11 +73,11 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
|
||||||
*/
|
*/
|
||||||
private def replaceWithProxy(
|
private def replaceWithProxy(
|
||||||
expr: Expression,
|
expr: Expression,
|
||||||
|
equivalentExpressions: EquivalentExpressions,
|
||||||
proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = {
|
proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = {
|
||||||
if (proxyMap.containsKey(expr)) {
|
equivalentExpressions.getExprState(expr) match {
|
||||||
proxyMap.get(expr)
|
case Some(stats) if proxyMap.containsKey(stats.expr) => proxyMap.get(stats.expr)
|
||||||
} else {
|
case _ => expr.mapChildren(replaceWithProxy(_, equivalentExpressions, proxyMap))
|
||||||
expr.mapChildren(replaceWithProxy(_, proxyMap))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,9 +91,8 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
|
||||||
|
|
||||||
val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]
|
val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]
|
||||||
|
|
||||||
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
|
val commonExprs = equivalentExpressions.getCommonSubexpressions
|
||||||
commonExprs.foreach { e =>
|
commonExprs.foreach { expr =>
|
||||||
val expr = e.head
|
|
||||||
val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
|
val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
|
||||||
proxyExpressionCurrentId += 1
|
proxyExpressionCurrentId += 1
|
||||||
|
|
||||||
|
@ -102,12 +101,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
|
||||||
// common expr2, ..., common expr n), we will insert into `proxyMap` some key/value
|
// common expr2, ..., common expr n), we will insert into `proxyMap` some key/value
|
||||||
// pairs like Map(common expr 1 -> proxy(common expr 1), ...,
|
// pairs like Map(common expr 1 -> proxy(common expr 1), ...,
|
||||||
// common expr n -> proxy(common expr 1)).
|
// common expr n -> proxy(common expr 1)).
|
||||||
e.map(proxyMap.put(_, proxy))
|
proxyMap.put(expr, proxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only adding proxy if we find subexpressions.
|
// Only adding proxy if we find subexpressions.
|
||||||
if (!proxyMap.isEmpty) {
|
if (!proxyMap.isEmpty) {
|
||||||
expressions.map(replaceWithProxy(_, proxyMap))
|
expressions.map(replaceWithProxy(_, equivalentExpressions, proxyMap))
|
||||||
} else {
|
} else {
|
||||||
expressions
|
expressions
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,9 +83,7 @@ object ExprCode {
|
||||||
* particular subexpressions, instead of all at once. In the case, we need
|
* particular subexpressions, instead of all at once. In the case, we need
|
||||||
* to make sure we evaluate all children subexpressions too.
|
* to make sure we evaluate all children subexpressions too.
|
||||||
*/
|
*/
|
||||||
case class SubExprEliminationState(
|
case class SubExprEliminationState(eval: ExprCode, children: Seq[SubExprEliminationState])
|
||||||
eval: ExprCode,
|
|
||||||
children: Seq[SubExprEliminationState])
|
|
||||||
|
|
||||||
object SubExprEliminationState {
|
object SubExprEliminationState {
|
||||||
def apply(eval: ExprCode): SubExprEliminationState = {
|
def apply(eval: ExprCode): SubExprEliminationState = {
|
||||||
|
@ -108,7 +106,7 @@ object SubExprEliminationState {
|
||||||
* calling common subexpressions.
|
* calling common subexpressions.
|
||||||
*/
|
*/
|
||||||
case class SubExprCodes(
|
case class SubExprCodes(
|
||||||
states: Map[Expression, SubExprEliminationState],
|
states: Map[ExpressionEquals, SubExprEliminationState],
|
||||||
exprCodesNeedEvaluate: Seq[ExprCode])
|
exprCodesNeedEvaluate: Seq[ExprCode])
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -426,7 +424,8 @@ class CodegenContext extends Logging {
|
||||||
|
|
||||||
// Foreach expression that is participating in subexpression elimination, the state to use.
|
// Foreach expression that is participating in subexpression elimination, the state to use.
|
||||||
// Visible for testing.
|
// Visible for testing.
|
||||||
private[expressions] var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState]
|
private[expressions] var subExprEliminationExprs =
|
||||||
|
Map.empty[ExpressionEquals, SubExprEliminationState]
|
||||||
|
|
||||||
// The collection of sub-expression result resetting methods that need to be called on each row.
|
// The collection of sub-expression result resetting methods that need to be called on each row.
|
||||||
private val subexprFunctions = mutable.ArrayBuffer.empty[String]
|
private val subexprFunctions = mutable.ArrayBuffer.empty[String]
|
||||||
|
@ -1031,7 +1030,7 @@ class CodegenContext extends Logging {
|
||||||
* expressions and common expressions, instead of using the mapping in current context.
|
* expressions and common expressions, instead of using the mapping in current context.
|
||||||
*/
|
*/
|
||||||
def withSubExprEliminationExprs(
|
def withSubExprEliminationExprs(
|
||||||
newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
|
newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])(
|
||||||
f: => Seq[ExprCode]): Seq[ExprCode] = {
|
f: => Seq[ExprCode]): Seq[ExprCode] = {
|
||||||
val oldsubExprEliminationExprs = subExprEliminationExprs
|
val oldsubExprEliminationExprs = subExprEliminationExprs
|
||||||
subExprEliminationExprs = newSubExprEliminationExprs
|
subExprEliminationExprs = newSubExprEliminationExprs
|
||||||
|
@ -1098,29 +1097,30 @@ class CodegenContext extends Logging {
|
||||||
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
|
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
|
||||||
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
|
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
|
||||||
val localSubExprEliminationExprsForNonSplit =
|
val localSubExprEliminationExprsForNonSplit =
|
||||||
mutable.HashMap.empty[Expression, SubExprEliminationState]
|
mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]
|
||||||
|
|
||||||
// Add each expression tree and compute the common subexpressions.
|
// Add each expression tree and compute the common subexpressions.
|
||||||
expressions.foreach(equivalentExpressions.addExprTree(_))
|
expressions.foreach(equivalentExpressions.addExprTree(_))
|
||||||
|
|
||||||
// Get all the expressions that appear at least twice and set up the state for subexpression
|
// Get all the expressions that appear at least twice and set up the state for subexpression
|
||||||
// elimination.
|
// elimination.
|
||||||
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
|
val commonExprs = equivalentExpressions.getCommonSubexpressions
|
||||||
|
|
||||||
val nonSplitCode = {
|
val nonSplitCode = {
|
||||||
val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
|
val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
|
||||||
commonExprs.map { exprs =>
|
commonExprs.map { expr =>
|
||||||
withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
|
withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
|
||||||
val eval = exprs.head.genCode(this)
|
val eval = expr.genCode(this)
|
||||||
// Collects other subexpressions from the children.
|
// Collects other subexpressions from the children.
|
||||||
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
|
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
|
||||||
exprs.head.foreach {
|
expr.foreach { e =>
|
||||||
case e if subExprEliminationExprs.contains(e) =>
|
subExprEliminationExprs.get(ExpressionEquals(e)) match {
|
||||||
childrenSubExprs += subExprEliminationExprs(e)
|
case Some(state) => childrenSubExprs += state
|
||||||
case _ =>
|
case _ =>
|
||||||
}
|
}
|
||||||
|
}
|
||||||
val state = SubExprEliminationState(eval, childrenSubExprs.toSeq)
|
val state = SubExprEliminationState(eval, childrenSubExprs.toSeq)
|
||||||
exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
|
localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state)
|
||||||
allStates += state
|
allStates += state
|
||||||
Seq(eval)
|
Seq(eval)
|
||||||
}
|
}
|
||||||
|
@ -1133,7 +1133,7 @@ class CodegenContext extends Logging {
|
||||||
// evaluate the outputs used more than twice. So we need to extract these variables used by
|
// evaluate the outputs used more than twice. So we need to extract these variables used by
|
||||||
// subexpressions and evaluate them before subexpressions.
|
// subexpressions and evaluate them before subexpressions.
|
||||||
val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr =>
|
val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr =>
|
||||||
val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head)
|
val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr)
|
||||||
(inputVars.toSeq, exprCodes.toSeq)
|
(inputVars.toSeq, exprCodes.toSeq)
|
||||||
}.unzip
|
}.unzip
|
||||||
|
|
||||||
|
@ -1141,10 +1141,9 @@ class CodegenContext extends Logging {
|
||||||
val (subExprsMap, exprCodes) = if (needSplit) {
|
val (subExprsMap, exprCodes) = if (needSplit) {
|
||||||
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
|
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
|
||||||
val localSubExprEliminationExprs =
|
val localSubExprEliminationExprs =
|
||||||
mutable.HashMap.empty[Expression, SubExprEliminationState]
|
mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]
|
||||||
|
|
||||||
commonExprs.zipWithIndex.foreach { case (exprs, i) =>
|
commonExprs.zipWithIndex.foreach { case (expr, i) =>
|
||||||
val expr = exprs.head
|
|
||||||
val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
|
val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
|
||||||
Seq(expr.genCode(this))
|
Seq(expr.genCode(this))
|
||||||
}.head
|
}.head
|
||||||
|
@ -1178,18 +1177,19 @@ class CodegenContext extends Logging {
|
||||||
|
|
||||||
// Collects other subexpressions from the children.
|
// Collects other subexpressions from the children.
|
||||||
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
|
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
|
||||||
exprs.head.foreach {
|
expr.foreach { e =>
|
||||||
case e if localSubExprEliminationExprs.contains(e) =>
|
localSubExprEliminationExprs.get(ExpressionEquals(e)) match {
|
||||||
childrenSubExprs += localSubExprEliminationExprs(e)
|
case Some(state) => childrenSubExprs += state
|
||||||
case _ =>
|
case _ =>
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
val inputVariables = inputVars.map(_.variableName).mkString(", ")
|
val inputVariables = inputVars.map(_.variableName).mkString(", ")
|
||||||
val code = code"${addNewFunction(fnName, fn)}($inputVariables);"
|
val code = code"${addNewFunction(fnName, fn)}($inputVariables);"
|
||||||
val state = SubExprEliminationState(
|
val state = SubExprEliminationState(
|
||||||
ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
|
ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
|
||||||
childrenSubExprs.toSeq)
|
childrenSubExprs.toSeq)
|
||||||
exprs.foreach(localSubExprEliminationExprs.put(_, state))
|
localSubExprEliminationExprs.put(ExpressionEquals(expr), state)
|
||||||
}
|
}
|
||||||
(localSubExprEliminationExprs, exprCodesNeedEvaluate)
|
(localSubExprEliminationExprs, exprCodesNeedEvaluate)
|
||||||
} else {
|
} else {
|
||||||
|
@ -1217,9 +1217,8 @@ class CodegenContext extends Logging {
|
||||||
|
|
||||||
// Get all the expressions that appear at least twice and set up the state for subexpression
|
// Get all the expressions that appear at least twice and set up the state for subexpression
|
||||||
// elimination.
|
// elimination.
|
||||||
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
|
val commonExprs = equivalentExpressions.getCommonSubexpressions
|
||||||
commonExprs.foreach { e =>
|
commonExprs.foreach { expr =>
|
||||||
val expr = e.head
|
|
||||||
val fnName = freshName("subExpr")
|
val fnName = freshName("subExpr")
|
||||||
val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
|
val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
|
||||||
val value = addMutableState(javaType(expr.dataType), "subExprValue")
|
val value = addMutableState(javaType(expr.dataType), "subExprValue")
|
||||||
|
@ -1255,7 +1254,7 @@ class CodegenContext extends Logging {
|
||||||
ExprCode(code"$subExprCode",
|
ExprCode(code"$subExprCode",
|
||||||
JavaCode.isNullGlobal(isNull),
|
JavaCode.isNullGlobal(isNull),
|
||||||
JavaCode.global(value, expr.dataType)))
|
JavaCode.global(value, expr.dataType)))
|
||||||
subExprEliminationExprs ++= e.map(_ -> state).toMap
|
subExprEliminationExprs += ExpressionEquals(expr) -> state
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1834,7 +1833,7 @@ object CodeGenerator extends Logging {
|
||||||
def getLocalInputVariableValues(
|
def getLocalInputVariableValues(
|
||||||
ctx: CodegenContext,
|
ctx: CodegenContext,
|
||||||
expr: Expression,
|
expr: Expression,
|
||||||
subExprs: Map[Expression, SubExprEliminationState] = Map.empty)
|
subExprs: Map[ExpressionEquals, SubExprEliminationState] = Map.empty)
|
||||||
: (Set[VariableValue], Set[ExprCode]) = {
|
: (Set[VariableValue], Set[ExprCode]) = {
|
||||||
val argSet = mutable.Set[VariableValue]()
|
val argSet = mutable.Set[VariableValue]()
|
||||||
val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
|
val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
|
||||||
|
@ -1852,10 +1851,6 @@ object CodeGenerator extends Logging {
|
||||||
val stack = mutable.Stack[Expression](expr)
|
val stack = mutable.Stack[Expression](expr)
|
||||||
while (stack.nonEmpty) {
|
while (stack.nonEmpty) {
|
||||||
stack.pop() match {
|
stack.pop() match {
|
||||||
case e if subExprs.contains(e) =>
|
|
||||||
collectLocalVariable(subExprs(e).eval.value)
|
|
||||||
collectLocalVariable(subExprs(e).eval.isNull)
|
|
||||||
|
|
||||||
case ref: BoundReference if ctx.currentVars != null &&
|
case ref: BoundReference if ctx.currentVars != null &&
|
||||||
ctx.currentVars(ref.ordinal) != null =>
|
ctx.currentVars(ref.ordinal) != null =>
|
||||||
val exprCode = ctx.currentVars(ref.ordinal)
|
val exprCode = ctx.currentVars(ref.ordinal)
|
||||||
|
@ -1868,9 +1863,15 @@ object CodeGenerator extends Logging {
|
||||||
collectLocalVariable(exprCode.isNull)
|
collectLocalVariable(exprCode.isNull)
|
||||||
|
|
||||||
case e =>
|
case e =>
|
||||||
|
subExprs.get(ExpressionEquals(e)) match {
|
||||||
|
case Some(state) =>
|
||||||
|
collectLocalVariable(state.eval.value)
|
||||||
|
collectLocalVariable(state.eval.isNull)
|
||||||
|
case None =>
|
||||||
stack.pushAll(e.children)
|
stack.pushAll(e.children)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
(argSet.toSet, exprCodesNeedEvaluate.toSet)
|
(argSet.toSet, exprCodesNeedEvaluate.toSet)
|
||||||
}
|
}
|
||||||
|
|
|
@ -325,11 +325,11 @@ object PhysicalAggregation {
|
||||||
case ae: AggregateExpression =>
|
case ae: AggregateExpression =>
|
||||||
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,
|
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,
|
||||||
// so replace each aggregate expression by its corresponding attribute in the set:
|
// so replace each aggregate expression by its corresponding attribute in the set:
|
||||||
equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
|
equivalentAggregateExpressions.getExprState(ae).map(_.expr)
|
||||||
.getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
|
.getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
|
||||||
// Similar to AggregateExpression
|
// Similar to AggregateExpression
|
||||||
case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
|
case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
|
||||||
equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
|
equivalentAggregateExpressions.getExprState(ue).map(_.expr)
|
||||||
.getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
|
.getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
|
||||||
case expression if !expression.foldable =>
|
case expression if !expression.foldable =>
|
||||||
// Since we're using `namedGroupingAttributes` to extract the grouping key
|
// Since we're using `namedGroupingAttributes` to extract the grouping key
|
||||||
|
|
|
@ -457,6 +457,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||||
Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
|
Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private def wrap(expr: Expression): ExpressionEquals = ExpressionEquals(expr)
|
||||||
|
|
||||||
test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") {
|
test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") {
|
||||||
|
|
||||||
val ref = BoundReference(0, IntegerType, true)
|
val ref = BoundReference(0, IntegerType, true)
|
||||||
|
@ -472,19 +474,19 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||||
val ctx = new CodegenContext
|
val ctx = new CodegenContext
|
||||||
val e = ref.genCode(ctx)
|
val e = ref.genCode(ctx)
|
||||||
// before
|
// before
|
||||||
ctx.subExprEliminationExprs += ref -> SubExprEliminationState(
|
ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState(
|
||||||
ExprCode(EmptyBlock, e.isNull, e.value))
|
ExprCode(EmptyBlock, e.isNull, e.value))
|
||||||
assert(ctx.subExprEliminationExprs.contains(ref))
|
assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
|
||||||
// call withSubExprEliminationExprs
|
// call withSubExprEliminationExprs
|
||||||
ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {
|
ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) {
|
||||||
assert(ctx.subExprEliminationExprs.contains(add1))
|
assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
|
||||||
assert(!ctx.subExprEliminationExprs.contains(ref))
|
assert(!ctx.subExprEliminationExprs.contains(wrap(ref)))
|
||||||
Seq.empty
|
Seq.empty
|
||||||
}
|
}
|
||||||
// after
|
// after
|
||||||
assert(ctx.subExprEliminationExprs.nonEmpty)
|
assert(ctx.subExprEliminationExprs.nonEmpty)
|
||||||
assert(ctx.subExprEliminationExprs.contains(ref))
|
assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
|
||||||
assert(!ctx.subExprEliminationExprs.contains(add1))
|
assert(!ctx.subExprEliminationExprs.contains(wrap(add1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// emulate an actual codegen workload
|
// emulate an actual codegen workload
|
||||||
|
@ -492,17 +494,17 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||||
val ctx = new CodegenContext
|
val ctx = new CodegenContext
|
||||||
// before
|
// before
|
||||||
ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
|
ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
|
||||||
assert(ctx.subExprEliminationExprs.contains(add1))
|
assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
|
||||||
// call withSubExprEliminationExprs
|
// call withSubExprEliminationExprs
|
||||||
ctx.withSubExprEliminationExprs(Map(ref -> dummy)) {
|
ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) {
|
||||||
assert(ctx.subExprEliminationExprs.contains(ref))
|
assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
|
||||||
assert(!ctx.subExprEliminationExprs.contains(add1))
|
assert(!ctx.subExprEliminationExprs.contains(wrap(add1)))
|
||||||
Seq.empty
|
Seq.empty
|
||||||
}
|
}
|
||||||
// after
|
// after
|
||||||
assert(ctx.subExprEliminationExprs.nonEmpty)
|
assert(ctx.subExprEliminationExprs.nonEmpty)
|
||||||
assert(ctx.subExprEliminationExprs.contains(add1))
|
assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
|
||||||
assert(!ctx.subExprEliminationExprs.contains(ref))
|
assert(!ctx.subExprEliminationExprs.contains(wrap(ref)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -47,35 +47,32 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
|
|
||||||
test("Expression Equivalence - basic") {
|
test("Expression Equivalence - basic") {
|
||||||
val equivalence = new EquivalentExpressions
|
val equivalence = new EquivalentExpressions
|
||||||
assert(equivalence.getAllEquivalentExprs().isEmpty)
|
assert(equivalence.getAllExprStates().isEmpty)
|
||||||
|
|
||||||
val oneA = Literal(1)
|
val oneA = Literal(1)
|
||||||
val oneB = Literal(1)
|
val oneB = Literal(1)
|
||||||
val twoA = Literal(2)
|
val twoA = Literal(2)
|
||||||
var twoB = Literal(2)
|
var twoB = Literal(2)
|
||||||
|
|
||||||
assert(equivalence.getEquivalentExprs(oneA).isEmpty)
|
assert(equivalence.getExprState(oneA).isEmpty)
|
||||||
assert(equivalence.getEquivalentExprs(twoA).isEmpty)
|
assert(equivalence.getExprState(twoA).isEmpty)
|
||||||
|
|
||||||
// Add oneA and test if it is returned. Since it is a group of one, it does not.
|
// Add oneA and test if it is returned. Since it is a group of one, it does not.
|
||||||
assert(!equivalence.addExpr(oneA))
|
assert(!equivalence.addExpr(oneA))
|
||||||
assert(equivalence.getEquivalentExprs(oneA).size == 1)
|
assert(equivalence.getExprState(oneA).get.useCount == 1)
|
||||||
assert(equivalence.getEquivalentExprs(twoA).isEmpty)
|
assert(equivalence.getExprState(twoA).isEmpty)
|
||||||
assert(equivalence.addExpr((oneA)))
|
assert(equivalence.addExpr(oneA))
|
||||||
assert(equivalence.getEquivalentExprs(oneA).size == 2)
|
assert(equivalence.getExprState(oneA).get.useCount == 2)
|
||||||
|
|
||||||
// Add B and make sure they can see each other.
|
// Add B and make sure they can see each other.
|
||||||
assert(equivalence.addExpr(oneB))
|
assert(equivalence.addExpr(oneB))
|
||||||
// Use exists and reference equality because of how equals is defined.
|
// Use exists and reference equality because of how equals is defined.
|
||||||
assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB))
|
assert(equivalence.getExprState(oneA).exists(_.expr eq oneA))
|
||||||
assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA))
|
assert(equivalence.getExprState(oneB).exists(_.expr eq oneA))
|
||||||
assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
|
assert(equivalence.getExprState(twoA).isEmpty)
|
||||||
assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
|
assert(equivalence.getAllExprStates().size == 1)
|
||||||
assert(equivalence.getEquivalentExprs(twoA).isEmpty)
|
assert(equivalence.getAllExprStates().head.useCount == 3)
|
||||||
assert(equivalence.getAllEquivalentExprs().size == 1)
|
assert(equivalence.getAllExprStates().head.expr eq oneA)
|
||||||
assert(equivalence.getAllEquivalentExprs().head.size == 3)
|
|
||||||
assert(equivalence.getAllEquivalentExprs().head.contains(oneA))
|
|
||||||
assert(equivalence.getAllEquivalentExprs().head.contains(oneB))
|
|
||||||
|
|
||||||
val add1 = Add(oneA, oneB)
|
val add1 = Add(oneA, oneB)
|
||||||
val add2 = Add(oneA, oneB)
|
val add2 = Add(oneA, oneB)
|
||||||
|
@ -83,10 +80,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence.addExpr(add1)
|
equivalence.addExpr(add1)
|
||||||
equivalence.addExpr(add2)
|
equivalence.addExpr(add2)
|
||||||
|
|
||||||
assert(equivalence.getAllEquivalentExprs().size == 2)
|
assert(equivalence.getAllExprStates().size == 2)
|
||||||
assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
|
assert(equivalence.getExprState(add1).exists(_.expr eq add1))
|
||||||
assert(equivalence.getEquivalentExprs(add2).size == 2)
|
assert(equivalence.getExprState(add2).get.useCount == 2)
|
||||||
assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
|
assert(equivalence.getExprState(add2).exists(_.expr eq add1))
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Expression Equivalence - Trees") {
|
test("Expression Equivalence - Trees") {
|
||||||
|
@ -103,8 +100,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence.addExprTree(add2)
|
equivalence.addExprTree(add2)
|
||||||
|
|
||||||
// Should only have one equivalence for `one + two`
|
// Should only have one equivalence for `one + two`
|
||||||
assert(equivalence.getAllEquivalentExprs(1).size == 1)
|
assert(equivalence.getAllExprStates(1).size == 1)
|
||||||
assert(equivalence.getAllEquivalentExprs(1).head.size == 4)
|
assert(equivalence.getAllExprStates(1).head.useCount == 4)
|
||||||
|
|
||||||
// Set up the expressions
|
// Set up the expressions
|
||||||
// one * two,
|
// one * two,
|
||||||
|
@ -122,11 +119,11 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence.addExprTree(sum)
|
equivalence.addExprTree(sum)
|
||||||
|
|
||||||
// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
|
// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
|
||||||
assert(equivalence.getAllEquivalentExprs(1).size == 3)
|
assert(equivalence.getAllExprStates(1).size == 3)
|
||||||
assert(equivalence.getEquivalentExprs(mul).size == 3)
|
assert(equivalence.getExprState(mul).get.useCount == 3)
|
||||||
assert(equivalence.getEquivalentExprs(mul2).size == 3)
|
assert(equivalence.getExprState(mul2).get.useCount == 3)
|
||||||
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
|
assert(equivalence.getExprState(sqrt).get.useCount == 2)
|
||||||
assert(equivalence.getEquivalentExprs(sum).size == 1)
|
assert(equivalence.getExprState(sum).get.useCount == 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Expression equivalence - non deterministic") {
|
test("Expression equivalence - non deterministic") {
|
||||||
|
@ -134,7 +131,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
val equivalence = new EquivalentExpressions
|
val equivalence = new EquivalentExpressions
|
||||||
equivalence.addExpr(sum)
|
equivalence.addExpr(sum)
|
||||||
equivalence.addExpr(sum)
|
equivalence.addExpr(sum)
|
||||||
assert(equivalence.getAllEquivalentExprs().isEmpty)
|
assert(equivalence.getAllExprStates().isEmpty)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Children of CodegenFallback") {
|
test("Children of CodegenFallback") {
|
||||||
|
@ -146,8 +143,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
val equivalence = new EquivalentExpressions
|
val equivalence = new EquivalentExpressions
|
||||||
equivalence.addExprTree(add)
|
equivalence.addExprTree(add)
|
||||||
// the `two` inside `fallback` should not be added
|
// the `two` inside `fallback` should not be added
|
||||||
assert(equivalence.getAllEquivalentExprs(1).size == 0)
|
assert(equivalence.getAllExprStates(1).size == 0)
|
||||||
assert(equivalence.getAllEquivalentExprs().count(_.size == 1) == 3) // add, two, explode
|
assert(equivalence.getAllExprStates().count(_.useCount == 1) == 3) // add, two, explode
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Children of conditional expressions: If") {
|
test("Children of conditional expressions: If") {
|
||||||
|
@ -159,35 +156,34 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence1.addExprTree(ifExpr1)
|
equivalence1.addExprTree(ifExpr1)
|
||||||
|
|
||||||
// `add` is in both two branches of `If` and predicate.
|
// `add` is in both two branches of `If` and predicate.
|
||||||
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
|
assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
|
||||||
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add, add))
|
assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add)
|
||||||
// one-time expressions: only ifExpr and its predicate expression
|
// one-time expressions: only ifExpr and its predicate expression
|
||||||
assert(equivalence1.getAllEquivalentExprs().count(_.size == 1) == 2)
|
assert(equivalence1.getAllExprStates().count(_.useCount == 1) == 2)
|
||||||
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
|
assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1))
|
||||||
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(condition)))
|
assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq condition))
|
||||||
|
|
||||||
// Repeated `add` is only in one branch, so we don't count it.
|
// Repeated `add` is only in one branch, so we don't count it.
|
||||||
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
|
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
|
||||||
val equivalence2 = new EquivalentExpressions
|
val equivalence2 = new EquivalentExpressions
|
||||||
equivalence2.addExprTree(ifExpr2)
|
equivalence2.addExprTree(ifExpr2)
|
||||||
|
|
||||||
assert(equivalence2.getAllEquivalentExprs(1).size == 0)
|
assert(equivalence2.getAllExprStates(1).isEmpty)
|
||||||
assert(equivalence2.getAllEquivalentExprs().count(_.size == 1) == 3)
|
assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3)
|
||||||
|
|
||||||
val ifExpr3 = If(condition, ifExpr1, ifExpr1)
|
val ifExpr3 = If(condition, ifExpr1, ifExpr1)
|
||||||
val equivalence3 = new EquivalentExpressions
|
val equivalence3 = new EquivalentExpressions
|
||||||
equivalence3.addExprTree(ifExpr3)
|
equivalence3.addExprTree(ifExpr3)
|
||||||
|
|
||||||
// `add`: 2, `condition`: 2
|
// `add`: 2, `condition`: 2
|
||||||
assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 2)
|
assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 2)
|
||||||
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(add, add)))
|
assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq condition))
|
||||||
assert(
|
assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq add))
|
||||||
equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(condition, condition)))
|
|
||||||
|
|
||||||
// `ifExpr1`, `ifExpr3`
|
// `ifExpr1`, `ifExpr3`
|
||||||
assert(equivalence3.getAllEquivalentExprs().count(_.size == 1) == 2)
|
assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 2)
|
||||||
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
|
assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1))
|
||||||
assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr3)))
|
assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr3))
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Children of conditional expressions: CaseWhen") {
|
test("Children of conditional expressions: CaseWhen") {
|
||||||
|
@ -202,8 +198,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence1.addExprTree(caseWhenExpr1)
|
equivalence1.addExprTree(caseWhenExpr1)
|
||||||
|
|
||||||
// `add2` is repeatedly in all conditions.
|
// `add2` is repeatedly in all conditions.
|
||||||
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
|
assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
|
||||||
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
|
assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2)
|
||||||
|
|
||||||
val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
|
val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
|
||||||
(GreaterThan(add2, Literal(4)), add1) ::
|
(GreaterThan(add2, Literal(4)), add1) ::
|
||||||
|
@ -214,8 +210,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence2.addExprTree(caseWhenExpr2)
|
equivalence2.addExprTree(caseWhenExpr2)
|
||||||
|
|
||||||
// `add1` is repeatedly in all branch values, and first predicate.
|
// `add1` is repeatedly in all branch values, and first predicate.
|
||||||
assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 1)
|
assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1)
|
||||||
assert(equivalence2.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add1, add1))
|
assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add1)
|
||||||
|
|
||||||
// Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
|
// Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
|
||||||
val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
|
val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
|
||||||
|
@ -225,7 +221,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
val caseWhenExpr3 = CaseWhen(conditions3, None)
|
val caseWhenExpr3 = CaseWhen(conditions3, None)
|
||||||
val equivalence3 = new EquivalentExpressions
|
val equivalence3 = new EquivalentExpressions
|
||||||
equivalence3.addExprTree(caseWhenExpr3)
|
equivalence3.addExprTree(caseWhenExpr3)
|
||||||
assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 0)
|
assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Children of conditional expressions: Coalesce") {
|
test("Children of conditional expressions: Coalesce") {
|
||||||
|
@ -240,8 +236,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence1.addExprTree(coalesceExpr1)
|
equivalence1.addExprTree(coalesceExpr1)
|
||||||
|
|
||||||
// `add2` is repeatedly in all conditions.
|
// `add2` is repeatedly in all conditions.
|
||||||
assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
|
assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
|
||||||
assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
|
assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2)
|
||||||
|
|
||||||
// Negative case. `add1` and `add2` both are not used in all branches.
|
// Negative case. `add1` and `add2` both are not used in all branches.
|
||||||
val conditions2 = GreaterThan(add1, Literal(3)) ::
|
val conditions2 = GreaterThan(add1, Literal(3)) ::
|
||||||
|
@ -252,7 +248,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
val equivalence2 = new EquivalentExpressions
|
val equivalence2 = new EquivalentExpressions
|
||||||
equivalence2.addExprTree(coalesceExpr2)
|
equivalence2.addExprTree(coalesceExpr2)
|
||||||
|
|
||||||
assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 0)
|
assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") {
|
test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") {
|
||||||
|
@ -321,9 +317,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
val equivalence = new EquivalentExpressions
|
val equivalence = new EquivalentExpressions
|
||||||
equivalence.addExprTree(caseWhenExpr)
|
equivalence.addExprTree(caseWhenExpr)
|
||||||
|
|
||||||
val commonExprs = equivalence.getAllEquivalentExprs(1)
|
val commonExprs = equivalence.getAllExprStates(1)
|
||||||
assert(commonExprs.size == 1)
|
assert(commonExprs.size == 1)
|
||||||
assert(commonExprs.head === Seq(add3, add3))
|
assert(commonExprs.head.useCount == 2)
|
||||||
|
assert(commonExprs.head.expr eq add3)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("SPARK-35439: Children subexpr should come first than parent subexpr") {
|
test("SPARK-35439: Children subexpr should come first than parent subexpr") {
|
||||||
|
@ -332,27 +329,29 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
val equivalence1 = new EquivalentExpressions
|
val equivalence1 = new EquivalentExpressions
|
||||||
|
|
||||||
equivalence1.addExprTree(add)
|
equivalence1.addExprTree(add)
|
||||||
assert(equivalence1.getAllEquivalentExprs().head === Seq(add))
|
assert(equivalence1.getAllExprStates().head.expr eq add)
|
||||||
|
|
||||||
equivalence1.addExprTree(Add(Literal(3), add))
|
equivalence1.addExprTree(Add(Literal(3), add))
|
||||||
assert(equivalence1.getAllEquivalentExprs() ===
|
assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 1))
|
||||||
Seq(Seq(add, add), Seq(Add(Literal(3), add))))
|
assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
|
||||||
|
|
||||||
equivalence1.addExprTree(Add(Literal(3), add))
|
equivalence1.addExprTree(Add(Literal(3), add))
|
||||||
assert(equivalence1.getAllEquivalentExprs() ===
|
assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 2))
|
||||||
Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
|
assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
|
||||||
|
|
||||||
val equivalence2 = new EquivalentExpressions
|
val equivalence2 = new EquivalentExpressions
|
||||||
|
|
||||||
equivalence2.addExprTree(Add(Literal(3), add))
|
equivalence2.addExprTree(Add(Literal(3), add))
|
||||||
assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add), Seq(Add(Literal(3), add))))
|
assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(1, 1))
|
||||||
|
assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
|
||||||
|
|
||||||
equivalence2.addExprTree(add)
|
equivalence2.addExprTree(add)
|
||||||
assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add, add), Seq(Add(Literal(3), add))))
|
assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 1))
|
||||||
|
assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
|
||||||
|
|
||||||
equivalence2.addExprTree(Add(Literal(3), add))
|
equivalence2.addExprTree(Add(Literal(3), add))
|
||||||
assert(equivalence2.getAllEquivalentExprs() ===
|
assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 2))
|
||||||
Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
|
assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
|
||||||
}
|
}
|
||||||
|
|
||||||
test("SPARK-35499: Subexpressions should only be extracted from CaseWhen values with an "
|
test("SPARK-35499: Subexpressions should only be extracted from CaseWhen values with an "
|
||||||
|
@ -368,28 +367,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
equivalence.addExprTree(caseWhenExpr)
|
equivalence.addExprTree(caseWhenExpr)
|
||||||
|
|
||||||
// `add1` is not in the elseValue, so we can't extract it from the branches
|
// `add1` is not in the elseValue, so we can't extract it from the branches
|
||||||
assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0)
|
assert(equivalence.getAllExprStates().count(_.useCount == 2) == 0)
|
||||||
}
|
|
||||||
|
|
||||||
test("SPARK-35439: sort exprs with ExpressionContainmentOrdering") {
|
|
||||||
val exprOrdering = new ExpressionContainmentOrdering
|
|
||||||
|
|
||||||
val add1 = Add(Literal(1), Literal(2))
|
|
||||||
val add2 = Add(Literal(2), Literal(3))
|
|
||||||
|
|
||||||
// Non parent-child expressions. Don't sort on them.
|
|
||||||
val exprs = Seq(add2, add1, add2, add1, add2, add1)
|
|
||||||
assert(exprs.sorted(exprOrdering) === exprs)
|
|
||||||
|
|
||||||
val conditions = (GreaterThan(add1, Literal(3)), add1) ::
|
|
||||||
(GreaterThan(add2, Literal(4)), add1) ::
|
|
||||||
(GreaterThan(add2, Literal(5)), add1) :: Nil
|
|
||||||
|
|
||||||
// `caseWhenExpr` contains add1, add2.
|
|
||||||
val caseWhenExpr = CaseWhen(conditions, None)
|
|
||||||
val exprs2 = Seq(caseWhenExpr, add2, add1, add2, add1, add2, add1)
|
|
||||||
assert(exprs2.sorted(exprOrdering) ===
|
|
||||||
Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
test("SPARK-35829: SubExprEliminationState keeps children sub exprs") {
|
test("SPARK-35829: SubExprEliminationState keeps children sub exprs") {
|
||||||
|
@ -400,8 +378,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
|
||||||
val ctx = new CodegenContext()
|
val ctx = new CodegenContext()
|
||||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
|
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
|
||||||
|
|
||||||
val add2State = subExprs.states(add2)
|
val add2State = subExprs.states(ExpressionEquals(add2))
|
||||||
val add1State = subExprs.states(add1)
|
val add1State = subExprs.states(ExpressionEquals(add1))
|
||||||
assert(add2State.children.contains(add1State))
|
assert(add2State.children.contains(add1State))
|
||||||
|
|
||||||
subExprs.states.values.foreach { state =>
|
subExprs.states.values.foreach { state =>
|
||||||
|
|
|
@ -257,7 +257,7 @@ case class HashAggregateExec(
|
||||||
aggNames: Seq[String],
|
aggNames: Seq[String],
|
||||||
aggBufferUpdatingExprs: Seq[Seq[Expression]],
|
aggBufferUpdatingExprs: Seq[Seq[Expression]],
|
||||||
aggCodeBlocks: Seq[Block],
|
aggCodeBlocks: Seq[Block],
|
||||||
subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
|
subExprs: Map[ExpressionEquals, SubExprEliminationState]): Option[Seq[String]] = {
|
||||||
val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
|
val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
|
||||||
s.eval.value :: s.eval.isNull :: Nil
|
s.eval.value :: s.eval.isNull :: Nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue