diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 71d3673346..9a26c388f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -90,8 +90,13 @@ case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
  * @param codes Strings representing the codes that evaluate common subexpressions.
  * @param states For each expression that participates in subexpression elimination,
  *               the state to use.
+ * @param exprCodesNeedEvaluate Expression codes that need to be evaluated before
+ *                              evaluating the common subexpressions.
  */
-case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
+case class SubExprCodes(
+    codes: Seq[String],
+    states: Map[Expression, SubExprEliminationState],
+    exprCodesNeedEvaluate: Seq[ExprCode])
 
 /**
  * The main information about a newly added function.
@@ -1044,7 +1049,7 @@ class CodegenContext extends Logging {
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
     val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
-    val commonExprVals = commonExprs.map(_.head.genCode(this))
+    lazy val commonExprVals = commonExprs.map(_.head.genCode(this))
 
     lazy val nonSplitExprCode = {
       commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
@@ -1055,10 +1060,17 @@ class CodegenContext extends Logging {
       }
     }
 
-    val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) {
-      val inputVarsForAllFuncs = commonExprs.map { expr =>
-        getLocalInputVariableValues(this, expr.head).toSeq
-      }
+    // Some operators do not require all of their child's output to be evaluated in advance;
+    // for example, `ProjectExec` only early-evaluates the outputs that are used more than once.
+    // So we need to extract the variables used by subexpressions and evaluate them before the
+    // subexpressions.
+    val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr =>
+      val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head)
+      (inputVars.toSeq, exprCodes.toSeq)
+    }.unzip
+
+    val splitThreshold = SQLConf.get.methodSplitThreshold
+    val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) {
       if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
         commonExprs.zipWithIndex.map { case (exprs, i) =>
           val expr = exprs.head
@@ -1109,7 +1121,7 @@ class CodegenContext extends Logging {
     } else {
       nonSplitExprCode
     }
-    SubExprCodes(codes, localSubExprEliminationExprs.toMap)
+    SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten)
   }
 
   /**
@@ -1732,15 +1744,23 @@ object CodeGenerator extends Logging {
   }
 
   /**
-   * Extracts all the input variables from references and subexpression elimination states
-   * for a given `expr`. This result will be used to split the generated code of
-   * expressions into multiple functions.
+   * This method returns two values in a tuple.
+   *
+   * First value: all the input variables extracted from references and subexpression
+   * elimination states for a given `expr`. This result will be used to split the
+   * generated code of expressions into multiple functions.
+   *
+   * Second value: the set of `ExprCode`s that need to be evaluated before the
+   * subexpressions are evaluated.
    */
   def getLocalInputVariableValues(
       ctx: CodegenContext,
       expr: Expression,
-      subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = {
+      subExprs: Map[Expression, SubExprEliminationState] = Map.empty)
+    : (Set[VariableValue], Set[ExprCode]) = {
     val argSet = mutable.Set[VariableValue]()
+    val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
+
     if (ctx.INPUT_ROW != null) {
       argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
     }
@@ -1761,16 +1781,21 @@
 
         case ref: BoundReference if ctx.currentVars != null &&
             ctx.currentVars(ref.ordinal) != null =>
-          val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
-          collectLocalVariable(value)
-          collectLocalVariable(isNull)
+          val exprCode = ctx.currentVars(ref.ordinal)
+          // If the referenced variable has not been evaluated yet, capture its code so the
+          // caller can evaluate it first, then clear it so it is not evaluated twice.
+          if (exprCode.code != EmptyBlock) {
+            exprCodesNeedEvaluate += exprCode.copy()
+            exprCode.code = EmptyBlock
+          }
+          collectLocalVariable(exprCode.value)
+          collectLocalVariable(exprCode.isNull)
 
         case e =>
           stack.pushAll(e.children)
       }
     }
 
-    argSet.toSet
+    (argSet.toSet, exprCodesNeedEvaluate.toSet)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index dcb465707a..52d0450afb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -263,7 +263,7 @@ case class HashAggregateExec(
     } else {
       val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
         val inputVarsForOneFunc = aggExprsForOneFunc.map(
-          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
+          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1).reduce(_ ++ _).toSeq
         val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
 
         // Checks that the parameter length for `aggExprsForOneFunc` does not go over the JVM limit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 1f70fde3f7..7334ea1e27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -66,10 +66,23 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val exprs = bindReferences[Expression](projectList, child.output)
-    val resultVars = exprs.map(_.genCode(ctx))
+    val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) {
+      // subexpression elimination
+      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
+      val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
+        exprs.map(_.genCode(ctx))
+      }
+      (subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate)
+    } else {
+      ("", exprs.map(_.genCode(ctx)), Seq.empty)
+    }
+
+    // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
+       |// common sub-expressions
+       |${evaluateVariables(localValInputs)}
+       |$subExprsCode
        |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
        |${consume(ctx, resultVars)}
      """.stripMargin
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index a9c521eb46..ec1ac00d08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -268,7 +268,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
           }
         }
         // reading this input data will fail midway through.
-        val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j)
+        val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j)
        val e3 = intercept[SparkException] {
          input.write.format(cls.getName).option("path", path).mode("overwrite").save()
        }
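
A quick way to see the effect of this patch is to inspect the generated code for a projection that repeats an expensive expression. The sketch below is illustrative only (the object name, query, and column names are invented and not part of the patch); it assumes spark.sql.subexpressionElimination.enabled is true (the default) and uses the existing debugCodegen() helper from org.apache.spark.sql.execution.debug.

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.execution.debug._

  object SubExprElimDemo {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder()
        .master("local[1]")
        .appName("subexpr-elim-demo")
        .getOrCreate()

      // `sqrt(id * id + 1)` appears twice in the projection. With this patch,
      // ProjectExec's whole-stage generated code evaluates the common
      // subexpression once and reuses the result for both output columns.
      val df = spark.range(1000L).selectExpr(
        "sqrt(id * id + 1) + 1 AS a",
        "sqrt(id * id + 1) - 1 AS b")

      // Dump the generated whole-stage code; the shared sqrt computation
      // should now appear once, ahead of the projection outputs.
      df.debugCodegen()

      spark.stop()
    }
  }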