diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4c1bfcfdf7..660a1dbaf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -403,13 +403,14 @@ class CodegenContext { * equivalentExpressions will match the tree containing `col1 + col2` and it will only * be evaluated once. */ - val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. - var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] + // Visible for testing. + private[expressions] var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] // The collection of sub-expression result resetting methods that need to be called on each row. - val subexprFunctions = mutable.ArrayBuffer.empty[String] + private val subexprFunctions = mutable.ArrayBuffer.empty[String] val outerClassName = "OuterClass" @@ -993,6 +994,15 @@ class CodegenContext { } } + /** + * Returns the code for subexpression elimination after splitting it if necessary. + */ + def subexprFunctionsCode: String = { + // Whole-stage codegen's subexpression elimination is handled in another code path + assert(currentVars == null || subexprFunctions.isEmpty) + splitExpressions(subexprFunctions, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) + } + /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 838bd1c679..2e018de071 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -92,7 +92,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP } // Evaluate all the subexpressions. - val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val evalSubexpr = ctx.subexprFunctionsCode val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index fb1d8a3c8e..8da7f65bde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -299,7 +299,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. - val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val evalSubexpr = ctx.subexprFunctionsCode val writeExpressions = writeExpressionsToBuffer( ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 4e64313da1..28d2607e6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -545,6 +545,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } assert(appender.seenMessage) } + + test("SPARK-28916: subexrepssion elimination can cause 64kb code limit on UnsafeProjection") { + val numOfExprs = 10000 + val exprs = (0 to numOfExprs).flatMap(colIndex => + Seq(Add(BoundReference(colIndex, DoubleType, true), + BoundReference(numOfExprs + colIndex, DoubleType, true)), + Add(BoundReference(colIndex, DoubleType, true), + BoundReference(numOfExprs + colIndex, DoubleType, true)))) + // these should not fail to compile due to 64K limit + GenerateUnsafeProjection.generate(exprs, true) + GenerateMutableProjection.generate(exprs, true) + } } case class HugeCodeIntExpression(value: Int) extends Expression {