[SPARK-33092][SQL] Support subexpression elimination in ProjectExec

### What changes were proposed in this pull request?

This patch proposes to add subexpression elimination support to `ProjectExec`. It can be controlled by the `spark.sql.subexpressionElimination.enabled` config.
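For reference, a minimal sketch of toggling the flag from user code (assuming an active `SparkSession` named `spark`); the key is the existing subexpression-elimination config, nothing new is added by this patch:

```scala
// The feature is on by default; turning it off falls back to evaluating each
// projection expression independently (no reuse of common subexpressions).
spark.conf.set("spark.sql.subexpressionElimination.enabled", "true")
// spark.conf.set("spark.sql.subexpressionElimination.enabled", "false")
```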

Before this change:

```scala
val df = spark.read.option("header", true).csv("/tmp/test.csv")
df.withColumn("my_map", expr("str_to_map(foo, '&', '=')"))
  .select(col("my_map")("foo"), col("my_map")("bar"), col("my_map")("baz"))
  .debugCodegen
```
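
Since the optimizer collapses the `withColumn` into the following `select`, the query above is effectively the hand-expanded projection below (a hypothetical rewrite, shown only to make the repetition visible), which is why the same `str_to_map` evaluation appears three times in the generated code:

```scala
// Hypothetical hand-expanded equivalent: each output column repeats the full
// str_to_map(foo, '&', '=') call when subexpression elimination is not applied.
df.select(
  expr("str_to_map(foo, '&', '=')['foo']"),
  expr("str_to_map(foo, '&', '=')['bar']"),
  expr("str_to_map(foo, '&', '=')['baz']")
).debugCodegen
```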

L27-40: first `str_to_map`.
L68-81: second `str_to_map`.
L109-122: third `str_to_map`.

```
/* 024 */   private void project_doConsume_0(InternalRow inputadapter_row_0, UTF8String project_expr_0_0, boolean project_exprIsNull_0_0) throws java.io.IOException {
/* 025 */     boolean project_isNull_0 = true;
/* 026 */     UTF8String project_value_0 = null;
/* 027 */     boolean project_isNull_1 = true;
/* 028 */     MapData project_value_1 = null;
/* 029 */
/* 030 */     if (!project_exprIsNull_0_0) {
/* 031 */       project_isNull_1 = false; // resultCode could change nullability.
/* 032 */
/* 033 */       UTF8String[] project_kvs_0 = project_expr_0_0.split(((UTF8String) references[1] /* literal */), -1);
/* 034 */       for(UTF8String kvEntry: project_kvs_0) {
/* 035 */         UTF8String[] kv = kvEntry.split(((UTF8String) references[2] /* literal */), 2);
/* 036 */         ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null);
/* 037 */       }
/* 038 */       project_value_1 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).build();
/* 039 */
/* 040 */     }
/* 041 */     if (!project_isNull_1) {
/* 042 */       project_isNull_0 = false; // resultCode could change nullability.
/* 043 */
/* 044 */       final int project_length_0 = project_value_1.numElements();
/* 045 */       final ArrayData project_keys_0 = project_value_1.keyArray();
/* 046 */       final ArrayData project_values_0 = project_value_1.valueArray();
/* 047 */
/* 048 */       int project_index_0 = 0;
/* 049 */       boolean project_found_0 = false;
/* 050 */       while (project_index_0 < project_length_0 && !project_found_0) {
/* 051 */         final UTF8String project_key_0 = project_keys_0.getUTF8String(project_index_0);
/* 052 */         if (project_key_0.equals(((UTF8String) references[3] /* literal */))) {
/* 053 */           project_found_0 = true;
/* 054 */         } else {
/* 055 */           project_index_0++;
/* 056 */         }
/* 057 */       }
/* 058 */
/* 059 */       if (!project_found_0 || project_values_0.isNullAt(project_index_0)) {
/* 060 */         project_isNull_0 = true;
/* 061 */       } else {
/* 062 */         project_value_0 = project_values_0.getUTF8String(project_index_0);
/* 063 */       }
/* 064 */
/* 065 */     }
/* 066 */     boolean project_isNull_6 = true;
/* 067 */     UTF8String project_value_6 = null;
/* 068 */     boolean project_isNull_7 = true;
/* 069 */     MapData project_value_7 = null;
/* 070 */
/* 071 */     if (!project_exprIsNull_0_0) {
/* 072 */       project_isNull_7 = false; // resultCode could change nullability.
/* 073 */
/* 074 */       UTF8String[] project_kvs_1 = project_expr_0_0.split(((UTF8String) references[5] /* literal */), -1);
/* 075 */       for(UTF8String kvEntry: project_kvs_1) {
/* 076 */         UTF8String[] kv = kvEntry.split(((UTF8String) references[6] /* literal */), 2);
/* 077 */         ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[4] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null);
/* 078 */       }
/* 079 */       project_value_7 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[4] /* mapBuilder */).build();
/* 080 */
/* 081 */     }
/* 082 */     if (!project_isNull_7) {
/* 083 */       project_isNull_6 = false; // resultCode could change nullability.
/* 084 */
/* 085 */       final int project_length_1 = project_value_7.numElements();
/* 086 */       final ArrayData project_keys_1 = project_value_7.keyArray();
/* 087 */       final ArrayData project_values_1 = project_value_7.valueArray();
/* 088 */
/* 089 */       int project_index_1 = 0;
/* 090 */       boolean project_found_1 = false;
/* 091 */       while (project_index_1 < project_length_1 && !project_found_1) {
/* 092 */         final UTF8String project_key_1 = project_keys_1.getUTF8String(project_index_1);
/* 093 */         if (project_key_1.equals(((UTF8String) references[7] /* literal */))) {
/* 094 */           project_found_1 = true;
/* 095 */         } else {
/* 096 */           project_index_1++;
/* 097 */         }
/* 098 */       }
/* 099 */
/* 100 */       if (!project_found_1 || project_values_1.isNullAt(project_index_1)) {
/* 101 */         project_isNull_6 = true;
/* 102 */       } else {
/* 103 */         project_value_6 = project_values_1.getUTF8String(project_index_1);
/* 104 */       }
/* 105 */
/* 106 */     }
/* 107 */     boolean project_isNull_12 = true;
/* 108 */     UTF8String project_value_12 = null;
/* 109 */     boolean project_isNull_13 = true;
/* 110 */     MapData project_value_13 = null;
/* 111 */
/* 112 */     if (!project_exprIsNull_0_0) {
/* 113 */       project_isNull_13 = false; // resultCode could change nullability.
/* 114 */
/* 115 */       UTF8String[] project_kvs_2 = project_expr_0_0.split(((UTF8String) references[9] /* literal */), -1);
/* 116 */       for(UTF8String kvEntry: project_kvs_2) {
/* 117 */         UTF8String[] kv = kvEntry.split(((UTF8String) references[10] /* literal */), 2);
/* 118 */         ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[8] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null);
/* 119 */       }
/* 120 */       project_value_13 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[8] /* mapBuilder */).build();
/* 121 */
/* 122 */     }
...
```
After this change:

L27-40 evaluates the common map value once; the key lookups that follow reuse `project_value_0` instead of rebuilding the map.

```
/* 024 */   private void project_doConsume_0(InternalRow inputadapter_row_0, UTF8String project_expr_0_0, boolean project_exprIsNull_0_0) throws java.io.IOException {
/* 025 */     // common sub-expressions
/* 026 */
/* 027 */     boolean project_isNull_0 = true;
/* 028 */     MapData project_value_0 = null;
/* 029 */
/* 030 */     if (!project_exprIsNull_0_0) {
/* 031 */       project_isNull_0 = false; // resultCode could change nullability.
/* 032 */
/* 033 */       UTF8String[] project_kvs_0 = project_expr_0_0.split(((UTF8String) references[1] /* literal */), -1);
/* 034 */       for(UTF8String kvEntry: project_kvs_0) {
/* 035 */         UTF8String[] kv = kvEntry.split(((UTF8String) references[2] /* literal */), 2);
/* 036 */         ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null);
/* 037 */       }
/* 038 */       project_value_0 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).build();
/* 039 */
/* 040 */     }
/* 041 */
/* 042 */     boolean project_isNull_4 = true;
/* 043 */     UTF8String project_value_4 = null;
/* 044 */
/* 045 */     if (!project_isNull_0) {
/* 046 */       project_isNull_4 = false; // resultCode could change nullability.
/* 047 */
/* 048 */       final int project_length_0 = project_value_0.numElements();
/* 049 */       final ArrayData project_keys_0 = project_value_0.keyArray();
/* 050 */       final ArrayData project_values_0 = project_value_0.valueArray();
/* 051 */
/* 052 */       int project_index_0 = 0;
/* 053 */       boolean project_found_0 = false;
/* 054 */       while (project_index_0 < project_length_0 && !project_found_0) {
/* 055 */         final UTF8String project_key_0 = project_keys_0.getUTF8String(project_index_0);
/* 056 */         if (project_key_0.equals(((UTF8String) references[3] /* literal */))) {
/* 057 */           project_found_0 = true;
/* 058 */         } else {
/* 059 */           project_index_0++;
/* 060 */         }
/* 061 */       }
/* 062 */
/* 063 */       if (!project_found_0 || project_values_0.isNullAt(project_index_0)) {
/* 064 */         project_isNull_4 = true;
/* 065 */       } else {
/* 066 */         project_value_4 = project_values_0.getUTF8String(project_index_0);
/* 067 */       }
/* 068 */
/* 069 */     }
/* 070 */     boolean project_isNull_6 = true;
/* 071 */     UTF8String project_value_6 = null;
/* 072 */
/* 073 */     if (!project_isNull_0) {
/* 074 */       project_isNull_6 = false; // resultCode could change nullability.
/* 075 */
/* 076 */       final int project_length_1 = project_value_0.numElements();
/* 077 */       final ArrayData project_keys_1 = project_value_0.keyArray();
/* 078 */       final ArrayData project_values_1 = project_value_0.valueArray();
/* 079 */
/* 080 */       int project_index_1 = 0;
/* 081 */       boolean project_found_1 = false;
/* 082 */       while (project_index_1 < project_length_1 && !project_found_1) {
/* 083 */         final UTF8String project_key_1 = project_keys_1.getUTF8String(project_index_1);
/* 084 */         if (project_key_1.equals(((UTF8String) references[4] /* literal */))) {
/* 085 */           project_found_1 = true;
/* 086 */         } else {
/* 087 */           project_index_1++;
/* 088 */         }
/* 089 */       }
/* 090 */
/* 091 */       if (!project_found_1 || project_values_1.isNullAt(project_index_1)) {
/* 092 */         project_isNull_6 = true;
/* 093 */       } else {
/* 094 */         project_value_6 = project_values_1.getUTF8String(project_index_1);
/* 095 */       }
/* 096 */
/* 097 */     }
/* 098 */     boolean project_isNull_8 = true;
/* 099 */     UTF8String project_value_8 = null;
/* 100 */
...
```

When the generated code is split into a separate method:

```
/* 026 */   private void project_doConsume_0(InternalRow inputadapter_row_0, UTF8String project_expr_0_0, boolean project_exprIsNull_0_0) throws java.io.IOException {
/* 027 */     // common sub-expressions
/* 028 */
/* 029 */     MapData project_subExprValue_0 = project_subExpr_0(project_exprIsNull_0_0, project_expr_0_0);
/* 030 */
...
/* 140 */   private MapData project_subExpr_0(boolean project_exprIsNull_0_0, org.apache.spark.unsafe.types.UTF8String project_expr_0_0) {
/* 141 */     boolean project_isNull_0 = true;
/* 142 */     MapData project_value_0 = null;
/* 143 */
/* 144 */     if (!project_exprIsNull_0_0) {
/* 145 */       project_isNull_0 = false; // resultCode could change nullability.
/* 146 */
/* 147 */       UTF8String[] project_kvs_0 = project_expr_0_0.split(((UTF8String) references[1] /* literal */), -1);
/* 148 */       for(UTF8String kvEntry: project_kvs_0) {
/* 149 */         UTF8String[] kv = kvEntry.split(((UTF8String) references[2] /* literal */), 2);
/* 150 */         ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null);
/* 151 */       }
/* 152 */       project_value_0 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).build();
/* 153 */
/* 154 */     }
/* 155 */     project_subExprIsNull_0 = project_isNull_0;
/* 156 */     return project_value_0;
/* 157 */   }
```
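
Whether the common subexpression code stays inline or is moved into a separate method like `project_subExpr_0` above follows the existing codegen method-split heuristic (`SQLConf.get.methodSplitThreshold` in the `CodeGenerator` diff further down). A hedged way to experiment with the split path, assuming the usual config key for that threshold:

```scala
// Lower the method-split threshold so even small common subexpression code is
// extracted into separate methods (key assumed to be the existing codegen config).
spark.conf.set("spark.sql.codegen.methodSplitThreshold", "10")
```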

### Why are the changes needed?

Users occasionally write repeated expressions in a projection, and the query optimizer can also rewrite a query so that the same expression is evaluated many times within a single `Project`. Currently `ProjectExec` does not support subexpression elimination in whole-stage codegen. Supporting it removes the redundant evaluation.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

`spark.sql.subexpressionElimination.enabled` is enabled by default, so this code path is exercised by the existing test suites; all existing tests should pass with this change.
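
On top of that, a quick manual sanity check is to re-run the `debugCodegen` example from above and confirm the map is now built only once (a sketch reusing the same hypothetical `/tmp/test.csv` input and `df`):

```scala
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions.{col, expr}

val projected = df
  .withColumn("my_map", expr("str_to_map(foo, '&', '=')"))
  .select(col("my_map")("foo"), col("my_map")("bar"), col("my_map")("baz"))

// With the flag on, the generated code should contain a single
// ArrayBasedMapBuilder.build() call instead of one per output column.
projected.debugCodegen()
```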

Closes #29975 from viirya/SPARK-33092.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
4 changed files with 56 additions and 18 deletions


```diff
@@ -90,8 +90,13 @@ case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
  * @param codes Strings representing the codes that evaluate common subexpressions.
  * @param states Foreach expression that is participating in subexpression elimination,
  *               the state to use.
+ * @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before
+ *                              calling common subexpressions.
  */
-case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
+case class SubExprCodes(
+    codes: Seq[String],
+    states: Map[Expression, SubExprEliminationState],
+    exprCodesNeedEvaluate: Seq[ExprCode])
 
 /**
  * The main information about a new added function.
@@ -1044,7 +1049,7 @@ class CodegenContext extends Logging {
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
     val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
-    val commonExprVals = commonExprs.map(_.head.genCode(this))
+    lazy val commonExprVals = commonExprs.map(_.head.genCode(this))
 
     lazy val nonSplitExprCode = {
       commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
@@ -1055,10 +1060,17 @@ class CodegenContext extends Logging {
       }
     }
 
-    val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) {
-      val inputVarsForAllFuncs = commonExprs.map { expr =>
-        getLocalInputVariableValues(this, expr.head).toSeq
-      }
+    // For some operators, they do not require all its child's outputs to be evaluated in advance.
+    // Instead it only early evaluates part of outputs, for example, `ProjectExec` only early
+    // evaluate the outputs used more than twice. So we need to extract these variables used by
+    // subexpressions and evaluate them before subexpressions.
+    val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr =>
+      val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head)
+      (inputVars.toSeq, exprCodes.toSeq)
+    }.unzip
+
+    val splitThreshold = SQLConf.get.methodSplitThreshold
+    val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) {
       if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
         commonExprs.zipWithIndex.map { case (exprs, i) =>
           val expr = exprs.head
@@ -1109,7 +1121,7 @@ class CodegenContext extends Logging {
     } else {
       nonSplitExprCode
     }
-    SubExprCodes(codes, localSubExprEliminationExprs.toMap)
+    SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten)
   }
 
   /**
@@ -1732,15 +1744,23 @@ object CodeGenerator extends Logging {
   }
 
   /**
-   * Extracts all the input variables from references and subexpression elimination states
-   * for a given `expr`. This result will be used to split the generated code of
-   * expressions into multiple functions.
+   * This methods returns two values in a Tuple.
+   *
+   * First value: Extracts all the input variables from references and subexpression
+   * elimination states for a given `expr`. This result will be used to split the
+   * generated code of expressions into multiple functions.
+   *
+   * Second value: Returns the set of `ExprCodes`s which are necessary codes before
+   * evaluating subexpressions.
    */
   def getLocalInputVariableValues(
       ctx: CodegenContext,
       expr: Expression,
-      subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = {
+      subExprs: Map[Expression, SubExprEliminationState] = Map.empty)
+    : (Set[VariableValue], Set[ExprCode]) = {
     val argSet = mutable.Set[VariableValue]()
+    val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
+
     if (ctx.INPUT_ROW != null) {
       argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
     }
@@ -1761,16 +1781,21 @@ object CodeGenerator extends Logging {
         case ref: BoundReference if ctx.currentVars != null &&
             ctx.currentVars(ref.ordinal) != null =>
-          val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
-          collectLocalVariable(value)
-          collectLocalVariable(isNull)
+          val exprCode = ctx.currentVars(ref.ordinal)
+          // If the referred variable is not evaluated yet.
+          if (exprCode.code != EmptyBlock) {
+            exprCodesNeedEvaluate += exprCode.copy()
+            exprCode.code = EmptyBlock
+          }
+          collectLocalVariable(exprCode.value)
+          collectLocalVariable(exprCode.isNull)
 
         case e =>
           stack.pushAll(e.children)
       }
     }
 
-    argSet.toSet
+    (argSet.toSet, exprCodesNeedEvaluate.toSet)
   }
 
   /**
```

```diff
@@ -263,7 +263,7 @@ case class HashAggregateExec(
     } else {
       val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
         val inputVarsForOneFunc = aggExprsForOneFunc.map(
-          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
+          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1).reduce(_ ++ _).toSeq
         val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
 
         // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
```

```diff
@@ -66,10 +66,23 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val exprs = bindReferences[Expression](projectList, child.output)
-    val resultVars = exprs.map(_.genCode(ctx))
+    val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) {
+      // subexpression elimination
+      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
+      val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
+        exprs.map(_.genCode(ctx))
+      }
+      (subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate)
+    } else {
+      ("", exprs.map(_.genCode(ctx)), Seq.empty)
+    }
+
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
+       |// common sub-expressions
+       |${evaluateVariables(localValInputs)}
+       |$subExprsCode
        |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
        |${consume(ctx, resultVars)}
      """.stripMargin
```

```diff
@@ -268,7 +268,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
       }
     }
     // this input data will fail to read middle way.
-    val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j)
+    val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j)
     val e3 = intercept[SparkException] {
       input.write.format(cls.getName).option("path", path).mode("overwrite").save()
     }
```