[SPARK-22600][SQL] Fix 64kb limit for deeply nested expressions under wholestage codegen

## What changes were proposed in this pull request?

SPARK-22543 fixes the 64kb compile error for deeply nested expression for non-wholestage codegen. This PR extends it to support wholestage codegen.

This patch brings some util methods in to extract necessary parameters for an expression if it is split to a function.

The util methods are put in object `ExpressionCodegen` under `codegen`. The main entry is `getExpressionInputParams` which returns all necessary parameters to evaluate the given expression in a split function.

This util methods can be used to split expressions too. This is a TODO item later.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19813 from viirya/reduce-expr-code-for-wholestage.
This commit is contained in:
Liang-Chi Hsieh 2017-12-13 10:40:05 +08:00 committed by Wenchen Fan
parent 4117786a87
commit c7d0148615
6 changed files with 585 additions and 13 deletions

View file

@ -105,6 +105,12 @@ abstract class Expression extends TreeNode[Expression] {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode("", isNull, value))
eval.isNull = if (this.nullable) eval.isNull else "false"
// Records current input row and variables of this expression.
eval.inputRow = ctx.INPUT_ROW
eval.inputVars = findInputVars(ctx, eval)
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
// Add `this` in the comment.
@ -115,9 +121,29 @@ abstract class Expression extends TreeNode[Expression] {
}
}
/**
* Returns the input variables to this expression.
*/
private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = {
if (ctx.currentVars != null) {
this.collect {
case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null =>
ExprInputVar(exprCode = ctx.currentVars(ordinal),
dataType = b.dataType, nullable = b.nullable)
}
} else {
Seq.empty
}
}
/**
* In order to prevent 64kb compile error, reducing the size of generated codes by
* separating it into a function if the size exceeds a threshold.
*/
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this)
if (eval.code.trim.length > 1024 && funcParams.isDefined) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
val globalIsNull = ctx.freshName("globalIsNull")
ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
@ -132,9 +158,12 @@ abstract class Expression extends TreeNode[Expression] {
val newValue = ctx.freshName("value")
val funcName = ctx.freshName(nodeName)
val callParams = funcParams.map(_._1.mkString(", ")).get
val declParams = funcParams.map(_._2.mkString(", ")).get
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
|private $javaType $funcName($declParams) {
| ${eval.code.trim}
| $setIsNull
| return ${eval.value};
@ -142,7 +171,7 @@ abstract class Expression extends TreeNode[Expression] {
""".stripMargin)
eval.value = newValue
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
eval.code = s"$javaType $newValue = $funcFullName($callParams);"
}
}

View file

@ -55,8 +55,24 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* to null.
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
* @param inputRow A term that holds the input row name when generating this code.
* @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code.
*/
case class ExprCode(var code: String, var isNull: String, var value: String)
case class ExprCode(
var code: String,
var isNull: String,
var value: String,
var inputRow: String = null,
var inputVars: Seq[ExprInputVar] = Seq.empty)
/**
* Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]].
*
* @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable.
* @param dataType The data type of the input variable.
* @param nullable Whether the input variable can be null or not.
*/
case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean)
/**
* State used for subexpression elimination.
@ -1012,16 +1028,25 @@ class CodegenContext {
commonExprs.foreach { e =>
val expr = e.head
val fnName = freshName("evalExpr")
val isNull = s"${fnName}IsNull"
val isNull = if (expr.nullable) {
s"${fnName}IsNull"
} else {
""
}
val value = s"${fnName}Value"
// Generate the code for this expression tree and wrap it in a function.
val eval = expr.genCode(this)
val assignIsNull = if (expr.nullable) {
s"$isNull = ${eval.isNull};"
} else {
""
}
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${eval.code.trim}
| $isNull = ${eval.isNull};
| $assignIsNull
| $value = ${eval.value};
|}
""".stripMargin
@ -1039,12 +1064,17 @@ class CodegenContext {
// 2. Less code.
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.
addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;")
addMutableState(javaType(expr.dataType), value,
s"$value = ${defaultValue(expr.dataType)};")
if (expr.nullable) {
addMutableState(JAVA_BOOLEAN, isNull)
}
addMutableState(javaType(expr.dataType), value)
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
val state = if (expr.nullable) {
SubExprEliminationState(isNull, value)
} else {
SubExprEliminationState("false", value)
}
e.foreach(subExprEliminationExprs.put(_, state))
}
}

View file

@ -0,0 +1,269 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.codegen
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DataType
/**
* Defines util methods used in expression code generation.
*/
object ExpressionCodegen {
/**
* Given an expression, returns the all necessary parameters to evaluate it, so the generated
* code of this expression can be split in a function.
* The 1st string in returned tuple is the parameter strings used to call the function.
* The 2nd string in returned tuple is the parameter strings used to declare the function.
*
* Returns `None` if it can't produce valid parameters.
*
* Params to include:
* 1. Evaluated columns referred by this, children or deferred expressions.
* 2. Rows referred by this, children or deferred expressions.
* 3. Eliminated subexpressions referred by children expressions.
*/
def getExpressionInputParams(
ctx: CodegenContext,
expr: Expression): Option[(Seq[String], Seq[String])] = {
val subExprs = getSubExprInChildren(ctx, expr)
val subExprCodes = getSubExprCodes(ctx, subExprs)
val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) =>
ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable)
}
val paramsFromSubExprs = prepareFunctionParams(ctx, subVars)
val inputVars = getInputVarsForChildren(ctx, expr)
val paramsFromColumns = prepareFunctionParams(ctx, inputVars)
val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr)
val paramsFromRows = inputRows.distinct.filter(_ != null).map { row =>
(row, s"InternalRow $row")
}
val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length
// Maximum allowed parameter number for Java's method descriptor.
if (paramsLength > 255) {
None
} else {
val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip
val callParams = allParams._1.distinct
val declParams = allParams._2.distinct
Some((callParams, declParams))
}
}
/**
* Returns the eliminated subexpressions in the children expressions.
*/
def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = {
expr.children.flatMap { child =>
child.collect {
case e if ctx.subExprEliminationExprs.contains(e) => e
}
}.distinct
}
/**
* A small helper function to return `ExprCode`s that represent subexpressions.
*/
def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = {
subExprs.map { subExpr =>
val state = ctx.subExprEliminationExprs(subExpr)
ExprCode(code = "", value = state.value, isNull = state.isNull)
}
}
/**
* Retrieves previous input rows referred by children and deferred expressions.
*/
def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = {
expr.children.flatMap(getInputRows(ctx, _)).distinct
}
/**
* Given a child expression, retrieves previous input rows referred by it or deferred expressions
* which are needed to evaluate it.
*/
def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = {
child.flatMap {
// An expression directly evaluates on current input row.
case BoundReference(ordinal, _, _) if ctx.currentVars == null ||
ctx.currentVars(ordinal) == null =>
Seq(ctx.INPUT_ROW)
// An expression which is not evaluated yet. Tracks down to find input rows.
case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) =>
trackDownRow(ctx, ctx.currentVars(ordinal))
case _ => Seq.empty
}.distinct
}
/**
* Tracks down input rows referred by the generated code snippet.
*/
def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = {
val exprCodes = mutable.Queue[ExprCode](exprCode)
val inputRows = mutable.ArrayBuffer.empty[String]
while (exprCodes.nonEmpty) {
val curExprCode = exprCodes.dequeue()
if (curExprCode.inputRow != null) {
inputRows += curExprCode.inputRow
}
curExprCode.inputVars.foreach { inputVar =>
if (!isEvaluated(inputVar.exprCode)) {
exprCodes.enqueue(inputVar.exprCode)
}
}
}
inputRows
}
/**
* Retrieves previously evaluated columns referred by children and deferred expressions.
* Returned tuple contains the list of expressions and the list of generated codes.
*/
def getInputVarsForChildren(
ctx: CodegenContext,
expr: Expression): Seq[ExprInputVar] = {
expr.children.flatMap(getInputVars(ctx, _)).distinct
}
/**
* Given a child expression, retrieves previously evaluated columns referred by it or
* deferred expressions which are needed to evaluate it.
*/
def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = {
if (ctx.currentVars == null) {
return Seq.empty
}
child.flatMap {
// An evaluated variable.
case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null &&
isEvaluated(ctx.currentVars(ordinal)) =>
Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable))
// An input variable which is not evaluated yet. Tracks down to find any evaluated variables
// in the expression path.
// E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to
// "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so
// to include them into parameters, if not, we track down further.
case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null =>
trackDownVar(ctx, ctx.currentVars(ordinal))
case _ => Seq.empty
}.distinct
}
/**
* Tracks down previously evaluated columns referred by the generated code snippet.
*/
def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = {
val exprCodes = mutable.Queue[ExprCode](exprCode)
val inputVars = mutable.ArrayBuffer.empty[ExprInputVar]
while (exprCodes.nonEmpty) {
exprCodes.dequeue().inputVars.foreach { inputVar =>
if (isEvaluated(inputVar.exprCode)) {
inputVars += inputVar
} else {
exprCodes.enqueue(inputVar.exprCode)
}
}
}
inputVars
}
/**
* Helper function to calculate the size of an expression as function parameter.
*/
def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = {
(if (input.nullable) 1 else 0) + ctx.javaType(input.dataType) match {
case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 2
case _ => 1
}
}
/**
* In Java, a method descriptor is valid only if it represents method parameters with a total
* length of 255 or less. `this` contributes one unit and a parameter of type long or double
* contributes two units.
*/
def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = {
// Initial value is 1 for `this`.
1 + inputs.map(calculateParamLength(ctx, _)).sum
}
/**
* Given the lists of input attributes and variables to this expression, returns the strings of
* funtion parameters. The first is the variable names used to call the function, the second is
* the parameters used to declare the function in generated code.
*/
def prepareFunctionParams(
ctx: CodegenContext,
inputVars: Seq[ExprInputVar]): Seq[(String, String)] = {
inputVars.flatMap { inputVar =>
val params = mutable.ArrayBuffer.empty[(String, String)]
val ev = inputVar.exprCode
// Only include the expression value if it is not a literal.
if (!isLiteral(ev)) {
val argType = ctx.javaType(inputVar.dataType)
params += ((ev.value, s"$argType ${ev.value}"))
}
// If it is a nullable expression and `isNull` is not a literal.
if (inputVar.nullable && ev.isNull != "true" && ev.isNull != "false") {
params += ((ev.isNull, s"boolean ${ev.isNull}"))
}
params
}.distinct
}
/**
* Only applied to the `ExprCode` in `ctx.currentVars`.
* Returns true if this value is a literal.
*/
def isLiteral(exprCode: ExprCode): Boolean = {
assert(exprCode.value.nonEmpty, "ExprCode.value can't be empty string.")
if (exprCode.value == "true" || exprCode.value == "false" || exprCode.value == "null") {
true
} else {
// The valid characters for the first character of a Java variable is [a-zA-Z_$].
exprCode.value.head match {
case v if v >= 'a' && v <= 'z' => false
case v if v >= 'A' && v <= 'Z' => false
case '_' | '$' => false
case _ => true
}
}
}
/**
* Only applied to the `ExprCode` in `ctx.currentVars`.
* The code is emptied after evaluation.
*/
def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == ""
}

View file

@ -0,0 +1,220 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType
class ExpressionCodegenSuite extends SparkFunSuite {
test("Returns eliminated subexpressions for expression") {
val ctx = new CodegenContext()
val subExpr = Add(Literal(1), Literal(2))
val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4)))
ctx.generateExpressions(exprs, doSubexpressionElimination = true)
val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0))
assert(subexpressions.length == 1 && subexpressions(0) == subExpr)
}
test("Gets parameters for subexpressions") {
val ctx = new CodegenContext()
val subExprs = Seq(
Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable
Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)())) // nullable
ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1"))
ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2"))
val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs)
val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) =>
ExprInputVar(exprCode, expr.dataType, expr.nullable)
}
val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars)
assert(params.length == 3)
assert(params(0) == Tuple2("value1", "int value1"))
assert(params(1) == Tuple2("value2", "int value2"))
assert(params(2) == Tuple2("isNull2", "boolean isNull2"))
}
test("Returns input variables for expression: current variables") {
val ctx = new CodegenContext()
val currentVars = Seq(
ExprCode("", isNull = "false", value = "value1"), // evaluated
ExprCode("", isNull = "isNull2", value = "value2"), // evaluated
ExprCode("fake code;", isNull = "isNull3", value = "value3")) // not evaluated
ctx.currentVars = currentVars
ctx.INPUT_ROW = null
val expr = If(Literal(false),
Add(BoundReference(0, IntegerType, nullable = false),
BoundReference(1, IntegerType, nullable = true)),
BoundReference(2, IntegerType, nullable = true))
val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
// Only two evaluated variables included.
assert(inputVars.length == 2)
assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false)
assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true)
assert(inputVars(0).exprCode == currentVars(0))
assert(inputVars(1).exprCode == currentVars(1))
val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
assert(params.length == 3)
assert(params(0) == Tuple2("value1", "int value1"))
assert(params(1) == Tuple2("value2", "int value2"))
assert(params(2) == Tuple2("isNull2", "boolean isNull2"))
}
test("Returns input variables for expression: deferred variables") {
val ctx = new CodegenContext()
// The referred column is not evaluated yet. But it depends on an evaluated column from
// other operator.
val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1"))
// currentVars(0) depends on this evaluated column.
currentVars(0).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull2", value = "value2"),
dataType = IntegerType, nullable = true))
ctx.currentVars = currentVars
ctx.INPUT_ROW = null
val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false))
val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
assert(inputVars.length == 1)
assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true)
val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
assert(params.length == 2)
assert(params(0) == Tuple2("value2", "int value2"))
assert(params(1) == Tuple2("isNull2", "boolean isNull2"))
}
test("Returns input rows for expression") {
val ctx = new CodegenContext()
ctx.currentVars = null
ctx.INPUT_ROW = "i"
val expr = Add(BoundReference(0, IntegerType, nullable = false),
BoundReference(1, IntegerType, nullable = true))
val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
assert(inputRows.length == 1)
assert(inputRows(0) == "i")
}
test("Returns input rows for expression: deferred expression") {
val ctx = new CodegenContext()
// The referred column is not evaluated yet. But it depends on an input row from
// other operator.
val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1"))
currentVars(0).inputRow = "inputadaptor_row1"
ctx.currentVars = currentVars
ctx.INPUT_ROW = null
val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false))
val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
assert(inputRows.length == 1)
assert(inputRows(0) == "inputadaptor_row1")
}
test("Returns both input rows and variables for expression") {
val ctx = new CodegenContext()
// 5 input variables in currentVars:
// 1 evaluated variable (value1).
// 3 not evaluated variables.
// value2 depends on an evaluated column from other operator.
// value3 depends on an input row from other operator.
// value4 depends on a not evaluated yet column from other operator.
// 1 null indicating to use input row "i".
val currentVars = Seq(
ExprCode("", isNull = "false", value = "value1"),
ExprCode("fake code;", isNull = "isNull2", value = "value2"),
ExprCode("fake code;", isNull = "isNull3", value = "value3"),
ExprCode("fake code;", isNull = "isNull4", value = "value4"),
null)
// value2 depends on this evaluated column.
currentVars(1).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull5", value = "value5"),
dataType = IntegerType, nullable = true))
// value3 depends on an input row "inputadaptor_row1".
currentVars(2).inputRow = "inputadaptor_row1"
// value4 depends on another not evaluated yet column.
currentVars(3).inputVars = Seq(ExprInputVar(ExprCode("fake code;",
isNull = "isNull6", value = "value6"), dataType = IntegerType, nullable = true))
ctx.currentVars = currentVars
ctx.INPUT_ROW = "i"
// expr: if (false) { value1 + value2 } else { (value3 + value4) + i[5] }
val expr = If(Literal(false),
Add(BoundReference(0, IntegerType, nullable = false),
BoundReference(1, IntegerType, nullable = true)),
Add(Add(BoundReference(2, IntegerType, nullable = true),
BoundReference(3, IntegerType, nullable = true)),
BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i".
// input rows: "i", "inputadaptor_row1".
val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
assert(inputRows.length == 2)
assert(inputRows(0) == "inputadaptor_row1")
assert(inputRows(1) == "i")
// input variables: value1 and value5
val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
assert(inputVars.length == 2)
// value1 has inlined isNull "false", so don't need to include it in the params.
val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
assert(inputVarParams.length == 3)
assert(inputVarParams(0) == Tuple2("value1", "int value1"))
assert(inputVarParams(1) == Tuple2("value5", "int value5"))
assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5"))
}
test("isLiteral: literals") {
val literals = Seq(
ExprCode("", "", "true"),
ExprCode("", "", "false"),
ExprCode("", "", "1"),
ExprCode("", "", "-1"),
ExprCode("", "", "1L"),
ExprCode("", "", "-1L"),
ExprCode("", "", "1.0f"),
ExprCode("", "", "-1.0f"),
ExprCode("", "", "0.1f"),
ExprCode("", "", "-0.1f"),
ExprCode("", "", """"string""""),
ExprCode("", "", "(byte)-1"),
ExprCode("", "", "(short)-1"),
ExprCode("", "", "null"))
literals.foreach(l => assert(ExpressionCodegen.isLiteral(l) == true))
}
test("isLiteral: non literals") {
val variables = Seq(
ExprCode("", "", "var1"),
ExprCode("", "", "_var2"),
ExprCode("", "", "$var3"),
ExprCode("", "", "v1a2r3"),
ExprCode("", "", "_1v2a3r"),
ExprCode("", "", "$1v2a3r"))
variables.foreach(v => assert(ExpressionCodegen.isLiteral(v) == false))
}
}

View file

@ -108,7 +108,10 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
|}""".stripMargin)
ctx.currentVars = null
// `rowIdx` isn't in `ctx.currentVars`. If the expressions are split later, we can't track it.
// So making it as global variable.
val rowidx = ctx.freshName("rowIdx")
ctx.addMutableState(ctx.JAVA_INT, rowidx)
val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
}
@ -128,7 +131,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
| int $numRows = $batch.numRows();
| int $localEnd = $numRows - $idx;
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
| int $rowidx = $idx + $localIdx;
| $rowidx = $idx + $localIdx;
| ${consume(ctx, columnsBatchInput).trim}
| $shouldStop
| }

View file

@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.{QueryTest, Row, SaveMode}
import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@ -236,4 +237,24 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
}
}
}
test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") {
import testImplicits._
withTempPath { dir =>
val path = dir.getCanonicalPath
val df = Seq(("abc", 1)).toDF("key", "int")
df.write.parquet(path)
var strExpr: Expression = col("key").expr
for (_ <- 1 to 150) {
strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8"))
}
val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*)
val plan = df2.queryExecution.executedPlan
assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
df2.collect()
}
}
}