[SPARK-35829][SQL] Clean up evaluates subexpressions and add more flexibility to evaluate particular subexpressoin

### What changes were proposed in this pull request?

This patch refactors the evaluation of subexpressions.

There are two changes:

1. Clean up subexpression code after evaluation to avoid duplicate evaluation.
2. Evaluate all children subexpressions when evaluating a subexpression.

### Why are the changes needed?

Currently `subexpressionEliminationForWholeStageCodegen` return the gen-ed code of subexpressions. The caller simply puts the code into its code block. We need more flexible evaluation here. For example, for Filter operator's subexpression evaluation, we may need to evaluate particular subexpression for one predicate. Current approach cannot satisfy the requirement.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

Closes #32980 from viirya/subexpr-eval.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
Liang-Chi Hsieh 2021-06-29 22:14:37 -07:00
parent 24b67ca9a8
commit 064230de97
6 changed files with 149 additions and 44 deletions

View file

@ -139,7 +139,10 @@ abstract class Expression extends TreeNode[Expression] {
ctx.subExprEliminationExprs.get(this).map { subExprState => ctx.subExprEliminationExprs.get(this).map { subExprState =>
// This expression is repeated which means that the code to evaluate it has already been added // This expression is repeated which means that the code to evaluate it has already been added
// as a function before. In that case, we just re-use it. // as a function before. In that case, we just re-use it.
ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value) ExprCode(
ctx.registerComment(this.toString),
subExprState.eval.isNull,
subExprState.eval.value)
}.getOrElse { }.getOrElse {
val isNull = ctx.freshName("isNull") val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value") val value = ctx.freshName("value")

View file

@ -76,24 +76,38 @@ object ExprCode {
/** /**
* State used for subexpression elimination. * State used for subexpression elimination.
* *
* @param isNull A term that holds a boolean value representing whether the expression evaluated * @param eval The source code for evaluating the subexpression.
* to null. * @param children The sequence of subexpressions as the children expressions. Before
* @param value A term for a value of a common sub-expression. Not valid if `isNull` * evaluating this subexpression, we should evaluate all children
* is set to `true`. * subexpressions first. This is used if we want to selectively evaluate
* particular subexpressions, instead of all at once. In the case, we need
* to make sure we evaluate all children subexpressions too.
*/ */
case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) case class SubExprEliminationState(
eval: ExprCode,
children: Seq[SubExprEliminationState])
object SubExprEliminationState {
def apply(eval: ExprCode): SubExprEliminationState = {
new SubExprEliminationState(eval, Seq.empty)
}
def apply(
eval: ExprCode,
children: Seq[SubExprEliminationState]): SubExprEliminationState = {
new SubExprEliminationState(eval, children.reverse)
}
}
/** /**
* Codes and common subexpressions mapping used for subexpression elimination. * Codes and common subexpressions mapping used for subexpression elimination.
* *
* @param codes Strings representing the codes that evaluate common subexpressions.
* @param states Foreach expression that is participating in subexpression elimination, * @param states Foreach expression that is participating in subexpression elimination,
* the state to use. * the state to use.
* @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before * @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before
* calling common subexpressions. * calling common subexpressions.
*/ */
case class SubExprCodes( case class SubExprCodes(
codes: Seq[String],
states: Map[Expression, SubExprEliminationState], states: Map[Expression, SubExprEliminationState],
exprCodesNeedEvaluate: Seq[ExprCode]) exprCodesNeedEvaluate: Seq[ExprCode])
@ -1030,11 +1044,55 @@ class CodegenContext extends Logging {
} }
/** /**
* Checks and sets up the state and codegen for subexpression elimination. This finds the * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After
* common subexpressions, generates the code snippets that evaluate those expressions and * evaluating a subexpression, this method will clean up the code block to avoid duplicate
* populates the mapping of common subexpressions to the generated code snippets. The generated * evaluation.
* code snippets will be returned and should be inserted into generated codes before these */
* common subexpressions actually are used first time. def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = {
val code = new StringBuilder()
subExprStates.foreach { state =>
val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code
code.append(currentCode + "\n")
state.eval.code = EmptyBlock
}
code.toString()
}
/**
* Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen.
*
* This finds the common subexpressions, generates the code snippets that evaluate those
* expressions and populates the mapping of common subexpressions to the generated code snippets.
*
* The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which
* contains an `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode`
* includes java source code, result variable name and is-null variable name of the subexpression.
*
* Besides, this also returns a sequences of `ExprCode` which are expression codes that need to
* be evaluated (as their input parameters) before evaluating subexpressions.
*
* To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with
* the `SubExprEliminationState`s to be evaluated. During generating the code, it will cleanup
* the states to avoid duplicate evaluation.
*
* The details of subexpression generation:
* 1. Gets subexpression set. See `EquivalentExpressions`.
* 2. Generate code of subexpressions as a whole block of code (non-split case)
* 3. Check if the total length of the above block is larger than the split-threshold. If so,
* try to split it in step 4, otherwise returning the non-split code block.
* 4. Check if parameter lengths of all subexpressions satisfy the JVM limitation, if so,
* try to split, otherwise returning the non-split code block.
* 5. For each subexpression, generating a function and put the code into it. To evaluate the
* subexpression, just call the function.
*
* The explanation of subexpression codegen:
* 1. Wrapping in `withSubExprEliminationExprs` call with current subexpression map. Each
* subexpression may depends on other subexpressions (children). So when generating code
* for subexpressions, we iterate over each subexpression and put the mapping between
* (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression
* evaluation, we can look for generated subexpressions and do replacement.
*/ */
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
// Create a clear EquivalentExpressions and SubExprEliminationState mapping // Create a clear EquivalentExpressions and SubExprEliminationState mapping
@ -1049,17 +1107,25 @@ class CodegenContext extends Logging {
// elimination. // elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
val nonSplitExprCode = { val nonSplitCode = {
val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
commonExprs.map { exprs => commonExprs.map { exprs =>
val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
val eval = exprs.head.genCode(this) val eval = exprs.head.genCode(this)
// Generate the code for this expression tree. // Collects other subexpressions from the children.
val state = SubExprEliminationState(eval.isNull, eval.value) val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
exprs.head.foreach {
case e if subExprEliminationExprs.contains(e) =>
childrenSubExprs += subExprEliminationExprs(e)
case _ =>
}
val state = SubExprEliminationState(eval, childrenSubExprs.toSeq)
exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state)) exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
allStates += state
Seq(eval) Seq(eval)
}.head }
eval.code.toString
} }
allStates.toSeq
} }
// For some operators, they do not require all its child's outputs to be evaluated in advance. // For some operators, they do not require all its child's outputs to be evaluated in advance.
@ -1071,14 +1137,13 @@ class CodegenContext extends Logging {
(inputVars.toSeq, exprCodes.toSeq) (inputVars.toSeq, exprCodes.toSeq)
}.unzip }.unzip
val splitThreshold = SQLConf.get.methodSplitThreshold val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold
val (subExprsMap, exprCodes) = if (needSplit) {
val (codes, subExprsMap, exprCodes) = if (nonSplitExprCode.map(_.length).sum > splitThreshold) {
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
val localSubExprEliminationExprs = val localSubExprEliminationExprs =
mutable.HashMap.empty[Expression, SubExprEliminationState] mutable.HashMap.empty[Expression, SubExprEliminationState]
val splitCodes = commonExprs.zipWithIndex.map { case (exprs, i) => commonExprs.zipWithIndex.foreach { case (exprs, i) =>
val expr = exprs.head val expr = exprs.head
val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
Seq(expr.genCode(this)) Seq(expr.genCode(this))
@ -1111,24 +1176,34 @@ class CodegenContext extends Logging {
|} |}
""".stripMargin """.stripMargin
val state = SubExprEliminationState(isNull, JavaCode.global(value, expr.dataType)) // Collects other subexpressions from the children.
exprs.foreach(localSubExprEliminationExprs.put(_, state)) val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
exprs.head.foreach {
case e if localSubExprEliminationExprs.contains(e) =>
childrenSubExprs += localSubExprEliminationExprs(e)
case _ =>
}
val inputVariables = inputVars.map(_.variableName).mkString(", ") val inputVariables = inputVars.map(_.variableName).mkString(", ")
s"${addNewFunction(fnName, fn)}($inputVariables);" val code = code"${addNewFunction(fnName, fn)}($inputVariables);"
val state = SubExprEliminationState(
ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
childrenSubExprs.toSeq)
exprs.foreach(localSubExprEliminationExprs.put(_, state))
} }
(splitCodes, localSubExprEliminationExprs, exprCodesNeedEvaluate) (localSubExprEliminationExprs, exprCodesNeedEvaluate)
} else { } else {
if (Utils.isTesting) { if (Utils.isTesting) {
throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH) throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH)
} else { } else {
logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH)) logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH))
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) (localSubExprEliminationExprsForNonSplit, Seq.empty)
} }
} }
} else { } else {
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) (localSubExprEliminationExprsForNonSplit, Seq.empty)
} }
SubExprCodes(codes, subExprsMap.toMap, exprCodes.flatten) SubExprCodes(subExprsMap.toMap, exprCodes.flatten)
} }
/** /**
@ -1174,10 +1249,12 @@ class CodegenContext extends Logging {
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low. // at least two nodes) as the cost of doing it is expected to be low.
val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState( val state = SubExprEliminationState(
JavaCode.isNullGlobal(isNull), ExprCode(code"$subExprCode",
JavaCode.global(value, expr.dataType)) JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType)))
subExprEliminationExprs ++= e.map(_ -> state).toMap subExprEliminationExprs ++= e.map(_ -> state).toMap
} }
} }
@ -1776,9 +1853,8 @@ object CodeGenerator extends Logging {
while (stack.nonEmpty) { while (stack.nonEmpty) {
stack.pop() match { stack.pop() match {
case e if subExprs.contains(e) => case e if subExprs.contains(e) =>
val SubExprEliminationState(isNull, value) = subExprs(e) collectLocalVariable(subExprs(e).eval.value)
collectLocalVariable(value) collectLocalVariable(subExprs(e).eval.isNull)
collectLocalVariable(isNull)
case ref: BoundReference if ctx.currentVars != null && case ref: BoundReference if ctx.currentVars != null &&
ctx.currentVars(ref.ordinal) != null => ctx.currentVars(ref.ordinal) != null =>

View file

@ -463,15 +463,17 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val add1 = Add(ref, ref) val add1 = Add(ref, ref)
val add2 = Add(add1, add1) val add2 = Add(add1, add1)
val dummy = SubExprEliminationState( val dummy = SubExprEliminationState(
JavaCode.variable("dummy", BooleanType), ExprCode(EmptyBlock,
JavaCode.variable("dummy", BooleanType)) JavaCode.variable("dummy", BooleanType),
JavaCode.variable("dummy", BooleanType)))
// raw testing of basic functionality // raw testing of basic functionality
{ {
val ctx = new CodegenContext val ctx = new CodegenContext
val e = ref.genCode(ctx) val e = ref.genCode(ctx)
// before // before
ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) ctx.subExprEliminationExprs += ref -> SubExprEliminationState(
ExprCode(EmptyBlock, e.isNull, e.value))
assert(ctx.subExprEliminationExprs.contains(ref)) assert(ctx.subExprEliminationExprs.contains(ref))
// call withSubExprEliminationExprs // call withSubExprEliminationExprs
ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {

View file

@ -282,7 +282,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
ctx.withSubExprEliminationExprs(subExprs.states) { ctx.withSubExprEliminationExprs(subExprs.states) {
exprs.map(_.genCode(ctx)) exprs.map(_.genCode(ctx))
} }
val subExprsCode = subExprs.codes.mkString("\n") val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val codeBody = s""" val codeBody = s"""
public java.lang.Object generate(Object[] references) { public java.lang.Object generate(Object[] references) {
@ -392,6 +392,27 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr)) Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr))
} }
test("SPARK-35829: SubExprEliminationState keeps children sub exprs") {
val add1 = Add(Literal(1), Literal(2))
val add2 = Add(add1, add1)
val exprs = Seq(add1, add1, add2, add2)
val ctx = new CodegenContext()
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
val add2State = subExprs.states(add2)
val add1State = subExprs.states(add1)
assert(add2State.children.contains(add1State))
subExprs.states.values.foreach { state =>
assert(state.eval.code != EmptyBlock)
}
ctx.evaluateSubExprEliminationState(subExprs.states.values)
subExprs.states.values.foreach { state =>
assert(state.eval.code == EmptyBlock)
}
}
test("SPARK-35886: PromotePrecision should not overwrite genCode") { test("SPARK-35886: PromotePrecision should not overwrite genCode") {
val p = PromotePrecision(Literal(Decimal("10.1"))) val p = PromotePrecision(Literal(Decimal("10.1")))

View file

@ -258,7 +258,9 @@ case class HashAggregateExec(
aggBufferUpdatingExprs: Seq[Seq[Expression]], aggBufferUpdatingExprs: Seq[Seq[Expression]],
aggCodeBlocks: Seq[Block], aggCodeBlocks: Seq[Block],
subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = { subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil } val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
s.eval.value :: s.eval.isNull :: Nil
}
if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) { if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
// `SimpleExprValue`s cannot be used as an input variable for split functions, so // `SimpleExprValue`s cannot be used as an input variable for split functions, so
// we give up splitting functions if it exists in `subExprs`. // we give up splitting functions if it exists in `subExprs`.
@ -363,7 +365,7 @@ case class HashAggregateExec(
bindReferences(updateExprsForOneFunc, inputAttrs) bindReferences(updateExprsForOneFunc, inputAttrs)
} }
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n") val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) { ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx)) boundUpdateExprsForOneFunc.map(_.genCode(ctx))
@ -989,7 +991,7 @@ case class HashAggregateExec(
bindReferences(updateExprsForOneFunc, inputAttrs) bindReferences(updateExprsForOneFunc, inputAttrs)
} }
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n") val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) { ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx)) boundUpdateExprsForOneFunc.map(_.genCode(ctx))
@ -1035,7 +1037,7 @@ case class HashAggregateExec(
bindReferences(updateExprsForOneFunc, inputAttrs) bindReferences(updateExprsForOneFunc, inputAttrs)
} }
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n") val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) { ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx)) boundUpdateExprsForOneFunc.map(_.genCode(ctx))

View file

@ -72,7 +72,8 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
exprs.map(_.genCode(ctx)) exprs.map(_.genCode(ctx))
} }
(subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate) (ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars,
subExprs.exprCodesNeedEvaluate)
} else { } else {
("", exprs.map(_.genCode(ctx)), Seq.empty) ("", exprs.map(_.genCode(ctx)), Seq.empty)
} }