[SPARK-33092][SQL] Support subexpression elimination in ProjectExec
### What changes were proposed in this pull request? This patch proposes to add subexpression elimination support into `ProjectExec`. It can be controlled by `spark.sql.subexpressionElimination.enabled` config. Before this change: ```scala val df = spark.read.option("header", true).csv("/tmp/test.csv") df.withColumn("my_map", expr("str_to_map(foo, '&', '=')")).select(col("my_map")("foo"), col("my_map")("bar"), col("my_map")("baz")).debugCodegen ``` L27-40: first `str_to_map`. L68-81: second `str_to_map`. L109-122: third `str_to_map`. ``` /* 024 */ private void project_doConsume_0(InternalRow inputadapter_row_0, UTF8String project_expr_0_0, boolean project_exprIsNull_0_0) throws java.io.IOException { /* 025 */ boolean project_isNull_0 = true; /* 026 */ UTF8String project_value_0 = null; /* 027 */ boolean project_isNull_1 = true; /* 028 */ MapData project_value_1 = null; /* 029 */ /* 030 */ if (!project_exprIsNull_0_0) { /* 031 */ project_isNull_1 = false; // resultCode could change nullability. /* 032 */ /* 033 */ UTF8String[] project_kvs_0 = project_expr_0_0.split(((UTF8String) references[1] /* literal */), -1); /* 034 */ for(UTF8String kvEntry: project_kvs_0) { /* 035 */ UTF8String[] kv = kvEntry.split(((UTF8String) references[2] /* literal */), 2); /* 036 */ ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null); /* 037 */ } /* 038 */ project_value_1 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).build(); /* 039 */ /* 040 */ } /* 041 */ if (!project_isNull_1) { /* 042 */ project_isNull_0 = false; // resultCode could change nullability. 
/* 043 */ /* 044 */ final int project_length_0 = project_value_1.numElements(); /* 045 */ final ArrayData project_keys_0 = project_value_1.keyArray(); /* 046 */ final ArrayData project_values_0 = project_value_1.valueArray(); /* 047 */ /* 048 */ int project_index_0 = 0; /* 049 */ boolean project_found_0 = false; /* 050 */ while (project_index_0 < project_length_0 && !project_found_0) { /* 051 */ final UTF8String project_key_0 = project_keys_0.getUTF8String(project_index_0); /* 052 */ if (project_key_0.equals(((UTF8String) references[3] /* literal */))) { /* 053 */ project_found_0 = true; /* 054 */ } else { /* 055 */ project_index_0++; /* 056 */ } /* 057 */ } /* 058 */ /* 059 */ if (!project_found_0 || project_values_0.isNullAt(project_index_0)) { /* 060 */ project_isNull_0 = true; /* 061 */ } else { /* 062 */ project_value_0 = project_values_0.getUTF8String(project_index_0); /* 063 */ } /* 064 */ /* 065 */ } /* 066 */ boolean project_isNull_6 = true; /* 067 */ UTF8String project_value_6 = null; /* 068 */ boolean project_isNull_7 = true; /* 069 */ MapData project_value_7 = null; /* 070 */ /* 071 */ if (!project_exprIsNull_0_0) { /* 072 */ project_isNull_7 = false; // resultCode could change nullability. /* 073 */ /* 074 */ UTF8String[] project_kvs_1 = project_expr_0_0.split(((UTF8String) references[5] /* literal */), -1); /* 075 */ for(UTF8String kvEntry: project_kvs_1) { /* 076 */ UTF8String[] kv = kvEntry.split(((UTF8String) references[6] /* literal */), 2); /* 077 */ ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[4] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null); /* 078 */ } /* 079 */ project_value_7 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[4] /* mapBuilder */).build(); /* 080 */ /* 081 */ } /* 082 */ if (!project_isNull_7) { /* 083 */ project_isNull_6 = false; // resultCode could change nullability. 
/* 084 */ /* 085 */ final int project_length_1 = project_value_7.numElements(); /* 086 */ final ArrayData project_keys_1 = project_value_7.keyArray(); /* 087 */ final ArrayData project_values_1 = project_value_7.valueArray(); /* 088 */ /* 089 */ int project_index_1 = 0; /* 090 */ boolean project_found_1 = false; /* 091 */ while (project_index_1 < project_length_1 && !project_found_1) { /* 092 */ final UTF8String project_key_1 = project_keys_1.getUTF8String(project_index_1); /* 093 */ if (project_key_1.equals(((UTF8String) references[7] /* literal */))) { /* 094 */ project_found_1 = true; /* 095 */ } else { /* 096 */ project_index_1++; /* 097 */ } /* 098 */ } /* 099 */ /* 100 */ if (!project_found_1 || project_values_1.isNullAt(project_index_1)) { /* 101 */ project_isNull_6 = true; /* 102 */ } else { /* 103 */ project_value_6 = project_values_1.getUTF8String(project_index_1); /* 104 */ } /* 105 */ /* 106 */ } /* 107 */ boolean project_isNull_12 = true; /* 108 */ UTF8String project_value_12 = null; /* 109 */ boolean project_isNull_13 = true; /* 110 */ MapData project_value_13 = null; /* 111 */ /* 112 */ if (!project_exprIsNull_0_0) { /* 113 */ project_isNull_13 = false; // resultCode could change nullability. /* 114 */ /* 115 */ UTF8String[] project_kvs_2 = project_expr_0_0.split(((UTF8String) references[9] /* literal */), -1); /* 116 */ for(UTF8String kvEntry: project_kvs_2) { /* 117 */ UTF8String[] kv = kvEntry.split(((UTF8String) references[10] /* literal */), 2); /* 118 */ ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[8] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null); /* 119 */ } /* 120 */ project_value_13 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[8] /* mapBuilder */).build(); /* 121 */ /* 122 */ } ... ``` After this change: L27-40 evaluates the common map variable. 
``` /* 024 */ private void project_doConsume_0(InternalRow inputadapter_row_0, UTF8String project_expr_0_0, boolean project_exprIsNull_0_0) throws java.io.IOException { /* 025 */ // common sub-expressions /* 026 */ /* 027 */ boolean project_isNull_0 = true; /* 028 */ MapData project_value_0 = null; /* 029 */ /* 030 */ if (!project_exprIsNull_0_0) { /* 031 */ project_isNull_0 = false; // resultCode could change nullability. /* 032 */ /* 033 */ UTF8String[] project_kvs_0 = project_expr_0_0.split(((UTF8String) references[1] /* literal */), -1); /* 034 */ for(UTF8String kvEntry: project_kvs_0) { /* 035 */ UTF8String[] kv = kvEntry.split(((UTF8String) references[2] /* literal */), 2); /* 036 */ ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null); /* 037 */ } /* 038 */ project_value_0 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).build(); /* 039 */ /* 040 */ } /* 041 */ /* 042 */ boolean project_isNull_4 = true; /* 043 */ UTF8String project_value_4 = null; /* 044 */ /* 045 */ if (!project_isNull_0) { /* 046 */ project_isNull_4 = false; // resultCode could change nullability. 
/* 047 */ /* 048 */ final int project_length_0 = project_value_0.numElements(); /* 049 */ final ArrayData project_keys_0 = project_value_0.keyArray(); /* 050 */ final ArrayData project_values_0 = project_value_0.valueArray(); /* 051 */ /* 052 */ int project_index_0 = 0; /* 053 */ boolean project_found_0 = false; /* 054 */ while (project_index_0 < project_length_0 && !project_found_0) { /* 055 */ final UTF8String project_key_0 = project_keys_0.getUTF8String(project_index_0); /* 056 */ if (project_key_0.equals(((UTF8String) references[3] /* literal */))) { /* 057 */ project_found_0 = true; /* 058 */ } else { /* 059 */ project_index_0++; /* 060 */ } /* 061 */ } /* 062 */ /* 063 */ if (!project_found_0 || project_values_0.isNullAt(project_index_0)) { /* 064 */ project_isNull_4 = true; /* 065 */ } else { /* 066 */ project_value_4 = project_values_0.getUTF8String(project_index_0); /* 067 */ } /* 068 */ /* 069 */ } /* 070 */ boolean project_isNull_6 = true; /* 071 */ UTF8String project_value_6 = null; /* 072 */ /* 073 */ if (!project_isNull_0) { /* 074 */ project_isNull_6 = false; // resultCode could change nullability. 
/* 075 */ /* 076 */ final int project_length_1 = project_value_0.numElements(); /* 077 */ final ArrayData project_keys_1 = project_value_0.keyArray(); /* 078 */ final ArrayData project_values_1 = project_value_0.valueArray(); /* 079 */ /* 080 */ int project_index_1 = 0; /* 081 */ boolean project_found_1 = false; /* 082 */ while (project_index_1 < project_length_1 && !project_found_1) { /* 083 */ final UTF8String project_key_1 = project_keys_1.getUTF8String(project_index_1); /* 084 */ if (project_key_1.equals(((UTF8String) references[4] /* literal */))) { /* 085 */ project_found_1 = true; /* 086 */ } else { /* 087 */ project_index_1++; /* 088 */ } /* 089 */ } /* 090 */ /* 091 */ if (!project_found_1 || project_values_1.isNullAt(project_index_1)) { /* 092 */ project_isNull_6 = true; /* 093 */ } else { /* 094 */ project_value_6 = project_values_1.getUTF8String(project_index_1); /* 095 */ } /* 096 */ /* 097 */ } /* 098 */ boolean project_isNull_8 = true; /* 099 */ UTF8String project_value_8 = null; /* 100 */ ... ``` When the code is split into separated method: ``` /* 026 */ private void project_doConsume_0(InternalRow inputadapter_row_0, UTF8String project_expr_0_0, boolean project_exprIsNull_0_0) throws java.io.IOException { /* 027 */ // common sub-expressions /* 028 */ /* 029 */ MapData project_subExprValue_0 = project_subExpr_0(project_exprIsNull_0_0, project_expr_0_0); /* 030 */ ... /* 140 */ private MapData project_subExpr_0(boolean project_exprIsNull_0_0, org.apache.spark.unsafe.types.UTF8String project_expr_0_0) { /* 141 */ boolean project_isNull_0 = true; /* 142 */ MapData project_value_0 = null; /* 143 */ /* 144 */ if (!project_exprIsNull_0_0) { /* 145 */ project_isNull_0 = false; // resultCode could change nullability. 
/* 146 */ /* 147 */ UTF8String[] project_kvs_0 = project_expr_0_0.split(((UTF8String) references[1] /* literal */), -1); /* 148 */ for(UTF8String kvEntry: project_kvs_0) { /* 149 */ UTF8String[] kv = kvEntry.split(((UTF8String) references[2] /* literal */), 2); /* 150 */ ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).put(kv[0], kv.length == 2 ? kv[1] : null); /* 151 */ } /* 152 */ project_value_0 = ((org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder) references[0] /* mapBuilder */).build(); /* 153 */ /* 154 */ } /* 155 */ project_subExprIsNull_0 = project_isNull_0; /* 156 */ return project_value_0; /* 157 */ } ``` ### Why are the changes needed? Users occasionally write repeated expressions in projections. It is also possible that the query optimizer optimizes a query to evaluate the same expression many times in a Project. Currently in ProjectExec, we don't support subexpression elimination in Whole-stage codegen. We can support it to reduce redundant evaluation. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? `spark.sql.subexpressionElimination.enabled` is enabled by default. That is to say, we should pass all tests with this change. Closes #29975 from viirya/SPARK-33092. Authored-by: Liang-Chi Hsieh <viirya@gmail.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
9896288b88
commit
78c0967bbe
|
@ -90,8 +90,13 @@ case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
|
|||
* @param codes Strings representing the codes that evaluate common subexpressions.
|
||||
* @param states Foreach expression that is participating in subexpression elimination,
|
||||
* the state to use.
|
||||
* @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before
|
||||
* calling common subexpressions.
|
||||
*/
|
||||
case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
|
||||
case class SubExprCodes(
|
||||
codes: Seq[String],
|
||||
states: Map[Expression, SubExprEliminationState],
|
||||
exprCodesNeedEvaluate: Seq[ExprCode])
|
||||
|
||||
/**
|
||||
* The main information about a new added function.
|
||||
|
@ -1044,7 +1049,7 @@ class CodegenContext extends Logging {
|
|||
// Get all the expressions that appear at least twice and set up the state for subexpression
|
||||
// elimination.
|
||||
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
|
||||
val commonExprVals = commonExprs.map(_.head.genCode(this))
|
||||
lazy val commonExprVals = commonExprs.map(_.head.genCode(this))
|
||||
|
||||
lazy val nonSplitExprCode = {
|
||||
commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
|
||||
|
@ -1055,10 +1060,17 @@ class CodegenContext extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) {
|
||||
val inputVarsForAllFuncs = commonExprs.map { expr =>
|
||||
getLocalInputVariableValues(this, expr.head).toSeq
|
||||
}
|
||||
// For some operators, they do not require all its child's outputs to be evaluated in advance.
|
||||
// Instead it only early evaluates part of outputs, for example, `ProjectExec` only early
|
||||
// evaluate the outputs used more than twice. So we need to extract these variables used by
|
||||
// subexpressions and evaluate them before subexpressions.
|
||||
val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr =>
|
||||
val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head)
|
||||
(inputVars.toSeq, exprCodes.toSeq)
|
||||
}.unzip
|
||||
|
||||
val splitThreshold = SQLConf.get.methodSplitThreshold
|
||||
val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) {
|
||||
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
|
||||
commonExprs.zipWithIndex.map { case (exprs, i) =>
|
||||
val expr = exprs.head
|
||||
|
@ -1109,7 +1121,7 @@ class CodegenContext extends Logging {
|
|||
} else {
|
||||
nonSplitExprCode
|
||||
}
|
||||
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
|
||||
SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1732,15 +1744,23 @@ object CodeGenerator extends Logging {
|
|||
}
|
||||
|
||||
/**
|
||||
* Extracts all the input variables from references and subexpression elimination states
|
||||
* for a given `expr`. This result will be used to split the generated code of
|
||||
* expressions into multiple functions.
|
||||
* This methods returns two values in a Tuple.
|
||||
*
|
||||
* First value: Extracts all the input variables from references and subexpression
|
||||
* elimination states for a given `expr`. This result will be used to split the
|
||||
* generated code of expressions into multiple functions.
|
||||
*
|
||||
* Second value: Returns the set of `ExprCodes`s which are necessary codes before
|
||||
* evaluating subexpressions.
|
||||
*/
|
||||
def getLocalInputVariableValues(
|
||||
ctx: CodegenContext,
|
||||
expr: Expression,
|
||||
subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = {
|
||||
subExprs: Map[Expression, SubExprEliminationState] = Map.empty)
|
||||
: (Set[VariableValue], Set[ExprCode]) = {
|
||||
val argSet = mutable.Set[VariableValue]()
|
||||
val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
|
||||
|
||||
if (ctx.INPUT_ROW != null) {
|
||||
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
|
||||
}
|
||||
|
@ -1761,16 +1781,21 @@ object CodeGenerator extends Logging {
|
|||
|
||||
case ref: BoundReference if ctx.currentVars != null &&
|
||||
ctx.currentVars(ref.ordinal) != null =>
|
||||
val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
|
||||
collectLocalVariable(value)
|
||||
collectLocalVariable(isNull)
|
||||
val exprCode = ctx.currentVars(ref.ordinal)
|
||||
// If the referred variable is not evaluated yet.
|
||||
if (exprCode.code != EmptyBlock) {
|
||||
exprCodesNeedEvaluate += exprCode.copy()
|
||||
exprCode.code = EmptyBlock
|
||||
}
|
||||
collectLocalVariable(exprCode.value)
|
||||
collectLocalVariable(exprCode.isNull)
|
||||
|
||||
case e =>
|
||||
stack.pushAll(e.children)
|
||||
}
|
||||
}
|
||||
|
||||
argSet.toSet
|
||||
(argSet.toSet, exprCodesNeedEvaluate.toSet)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -263,7 +263,7 @@ case class HashAggregateExec(
|
|||
} else {
|
||||
val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
|
||||
val inputVarsForOneFunc = aggExprsForOneFunc.map(
|
||||
CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
|
||||
CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1).reduce(_ ++ _).toSeq
|
||||
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
|
||||
|
||||
// Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
|
||||
|
|
|
@ -66,10 +66,23 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
|
|||
|
||||
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
|
||||
val exprs = bindReferences[Expression](projectList, child.output)
|
||||
val resultVars = exprs.map(_.genCode(ctx))
|
||||
val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) {
|
||||
// subexpression elimination
|
||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
|
||||
val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||
exprs.map(_.genCode(ctx))
|
||||
}
|
||||
(subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate)
|
||||
} else {
|
||||
("", exprs.map(_.genCode(ctx)), Seq.empty)
|
||||
}
|
||||
|
||||
// Evaluation of non-deterministic expressions can't be deferred.
|
||||
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
|
||||
s"""
|
||||
|// common sub-expressions
|
||||
|${evaluateVariables(localValInputs)}
|
||||
|$subExprsCode
|
||||
|${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
|
||||
|${consume(ctx, resultVars)}
|
||||
""".stripMargin
|
||||
|
|
|
@ -268,7 +268,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
|
|||
}
|
||||
}
|
||||
// this input data will fail to read middle way.
|
||||
val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j)
|
||||
val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j)
|
||||
val e3 = intercept[SparkException] {
|
||||
input.write.format(cls.getName).option("path", path).mode("overwrite").save()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue