[SPARK-22600][SQL] Fix 64kb limit for deeply nested expressions under wholestage codegen
## What changes were proposed in this pull request? SPARK-22543 fixes the 64kb compile error for deeply nested expression for non-wholestage codegen. This PR extends it to support wholestage codegen. This patch brings some util methods in to extract necessary parameters for an expression if it is split to a function. The util methods are put in object `ExpressionCodegen` under `codegen`. The main entry is `getExpressionInputParams` which returns all necessary parameters to evaluate the given expression in a split function. This util methods can be used to split expressions too. This is a TODO item later. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #19813 from viirya/reduce-expr-code-for-wholestage.
This commit is contained in:
parent
4117786a87
commit
c7d0148615
|
@ -105,6 +105,12 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
val isNull = ctx.freshName("isNull")
|
||||
val value = ctx.freshName("value")
|
||||
val eval = doGenCode(ctx, ExprCode("", isNull, value))
|
||||
eval.isNull = if (this.nullable) eval.isNull else "false"
|
||||
|
||||
// Records current input row and variables of this expression.
|
||||
eval.inputRow = ctx.INPUT_ROW
|
||||
eval.inputVars = findInputVars(ctx, eval)
|
||||
|
||||
reduceCodeSize(ctx, eval)
|
||||
if (eval.code.nonEmpty) {
|
||||
// Add `this` in the comment.
|
||||
|
@ -115,9 +121,29 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the input variables to this expression.
|
||||
*/
|
||||
private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = {
|
||||
if (ctx.currentVars != null) {
|
||||
this.collect {
|
||||
case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null =>
|
||||
ExprInputVar(exprCode = ctx.currentVars(ordinal),
|
||||
dataType = b.dataType, nullable = b.nullable)
|
||||
}
|
||||
} else {
|
||||
Seq.empty
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* In order to prevent 64kb compile error, reducing the size of generated codes by
|
||||
* separating it into a function if the size exceeds a threshold.
|
||||
*/
|
||||
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
|
||||
// TODO: support whole stage codegen too
|
||||
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
|
||||
lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this)
|
||||
|
||||
if (eval.code.trim.length > 1024 && funcParams.isDefined) {
|
||||
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
|
||||
val globalIsNull = ctx.freshName("globalIsNull")
|
||||
ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
|
||||
|
@ -132,9 +158,12 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
val newValue = ctx.freshName("value")
|
||||
|
||||
val funcName = ctx.freshName(nodeName)
|
||||
val callParams = funcParams.map(_._1.mkString(", ")).get
|
||||
val declParams = funcParams.map(_._2.mkString(", ")).get
|
||||
|
||||
val funcFullName = ctx.addNewFunction(funcName,
|
||||
s"""
|
||||
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
|
||||
|private $javaType $funcName($declParams) {
|
||||
| ${eval.code.trim}
|
||||
| $setIsNull
|
||||
| return ${eval.value};
|
||||
|
@ -142,7 +171,7 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
""".stripMargin)
|
||||
|
||||
eval.value = newValue
|
||||
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
|
||||
eval.code = s"$javaType $newValue = $funcFullName($callParams);"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -55,8 +55,24 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
|
|||
* to null.
|
||||
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
|
||||
* valid if `isNull` is set to `true`.
|
||||
* @param inputRow A term that holds the input row name when generating this code.
|
||||
* @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code.
|
||||
*/
|
||||
case class ExprCode(var code: String, var isNull: String, var value: String)
|
||||
case class ExprCode(
|
||||
var code: String,
|
||||
var isNull: String,
|
||||
var value: String,
|
||||
var inputRow: String = null,
|
||||
var inputVars: Seq[ExprInputVar] = Seq.empty)
|
||||
|
||||
/**
|
||||
* Represents an input variable [[ExprCode]] to an evaluation of an [[Expression]].
|
||||
*
|
||||
* @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable.
|
||||
* @param dataType The data type of the input variable.
|
||||
* @param nullable Whether the input variable can be null or not.
|
||||
*/
|
||||
case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean)
|
||||
|
||||
/**
|
||||
* State used for subexpression elimination.
|
||||
|
@ -1012,16 +1028,25 @@ class CodegenContext {
|
|||
commonExprs.foreach { e =>
|
||||
val expr = e.head
|
||||
val fnName = freshName("evalExpr")
|
||||
val isNull = s"${fnName}IsNull"
|
||||
val isNull = if (expr.nullable) {
|
||||
s"${fnName}IsNull"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
val value = s"${fnName}Value"
|
||||
|
||||
// Generate the code for this expression tree and wrap it in a function.
|
||||
val eval = expr.genCode(this)
|
||||
val assignIsNull = if (expr.nullable) {
|
||||
s"$isNull = ${eval.isNull};"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
val fn =
|
||||
s"""
|
||||
|private void $fnName(InternalRow $INPUT_ROW) {
|
||||
| ${eval.code.trim}
|
||||
| $isNull = ${eval.isNull};
|
||||
| $assignIsNull
|
||||
| $value = ${eval.value};
|
||||
|}
|
||||
""".stripMargin
|
||||
|
@ -1039,12 +1064,17 @@ class CodegenContext {
|
|||
// 2. Less code.
|
||||
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
|
||||
// at least two nodes) as the cost of doing it is expected to be low.
|
||||
addMutableState(JAVA_BOOLEAN, isNull, s"$isNull = false;")
|
||||
addMutableState(javaType(expr.dataType), value,
|
||||
s"$value = ${defaultValue(expr.dataType)};")
|
||||
if (expr.nullable) {
|
||||
addMutableState(JAVA_BOOLEAN, isNull)
|
||||
}
|
||||
addMutableState(javaType(expr.dataType), value)
|
||||
|
||||
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
|
||||
val state = SubExprEliminationState(isNull, value)
|
||||
val state = if (expr.nullable) {
|
||||
SubExprEliminationState(isNull, value)
|
||||
} else {
|
||||
SubExprEliminationState("false", value)
|
||||
}
|
||||
e.foreach(subExprEliminationExprs.put(_, state))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,269 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions.codegen
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
/**
|
||||
* Defines util methods used in expression code generation.
|
||||
*/
|
||||
object ExpressionCodegen {
|
||||
|
||||
/**
|
||||
* Given an expression, returns the all necessary parameters to evaluate it, so the generated
|
||||
* code of this expression can be split in a function.
|
||||
* The 1st string in returned tuple is the parameter strings used to call the function.
|
||||
* The 2nd string in returned tuple is the parameter strings used to declare the function.
|
||||
*
|
||||
* Returns `None` if it can't produce valid parameters.
|
||||
*
|
||||
* Params to include:
|
||||
* 1. Evaluated columns referred by this, children or deferred expressions.
|
||||
* 2. Rows referred by this, children or deferred expressions.
|
||||
* 3. Eliminated subexpressions referred by children expressions.
|
||||
*/
|
||||
def getExpressionInputParams(
|
||||
ctx: CodegenContext,
|
||||
expr: Expression): Option[(Seq[String], Seq[String])] = {
|
||||
val subExprs = getSubExprInChildren(ctx, expr)
|
||||
val subExprCodes = getSubExprCodes(ctx, subExprs)
|
||||
val subVars = subExprs.zip(subExprCodes).map { case (subExpr, subExprCode) =>
|
||||
ExprInputVar(subExprCode, subExpr.dataType, subExpr.nullable)
|
||||
}
|
||||
val paramsFromSubExprs = prepareFunctionParams(ctx, subVars)
|
||||
|
||||
val inputVars = getInputVarsForChildren(ctx, expr)
|
||||
val paramsFromColumns = prepareFunctionParams(ctx, inputVars)
|
||||
|
||||
val inputRows = ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr)
|
||||
val paramsFromRows = inputRows.distinct.filter(_ != null).map { row =>
|
||||
(row, s"InternalRow $row")
|
||||
}
|
||||
|
||||
val paramsLength = getParamLength(ctx, inputVars ++ subVars) + paramsFromRows.length
|
||||
// Maximum allowed parameter number for Java's method descriptor.
|
||||
if (paramsLength > 255) {
|
||||
None
|
||||
} else {
|
||||
val allParams = (paramsFromRows ++ paramsFromColumns ++ paramsFromSubExprs).unzip
|
||||
val callParams = allParams._1.distinct
|
||||
val declParams = allParams._2.distinct
|
||||
Some((callParams, declParams))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the eliminated subexpressions in the children expressions.
|
||||
*/
|
||||
def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = {
|
||||
expr.children.flatMap { child =>
|
||||
child.collect {
|
||||
case e if ctx.subExprEliminationExprs.contains(e) => e
|
||||
}
|
||||
}.distinct
|
||||
}
|
||||
|
||||
/**
|
||||
* A small helper function to return `ExprCode`s that represent subexpressions.
|
||||
*/
|
||||
def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = {
|
||||
subExprs.map { subExpr =>
|
||||
val state = ctx.subExprEliminationExprs(subExpr)
|
||||
ExprCode(code = "", value = state.value, isNull = state.isNull)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves previous input rows referred by children and deferred expressions.
|
||||
*/
|
||||
def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = {
|
||||
expr.children.flatMap(getInputRows(ctx, _)).distinct
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a child expression, retrieves previous input rows referred by it or deferred expressions
|
||||
* which are needed to evaluate it.
|
||||
*/
|
||||
def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = {
|
||||
child.flatMap {
|
||||
// An expression directly evaluates on current input row.
|
||||
case BoundReference(ordinal, _, _) if ctx.currentVars == null ||
|
||||
ctx.currentVars(ordinal) == null =>
|
||||
Seq(ctx.INPUT_ROW)
|
||||
|
||||
// An expression which is not evaluated yet. Tracks down to find input rows.
|
||||
case BoundReference(ordinal, _, _) if !isEvaluated(ctx.currentVars(ordinal)) =>
|
||||
trackDownRow(ctx, ctx.currentVars(ordinal))
|
||||
|
||||
case _ => Seq.empty
|
||||
}.distinct
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracks down input rows referred by the generated code snippet.
|
||||
*/
|
||||
def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = {
|
||||
val exprCodes = mutable.Queue[ExprCode](exprCode)
|
||||
val inputRows = mutable.ArrayBuffer.empty[String]
|
||||
|
||||
while (exprCodes.nonEmpty) {
|
||||
val curExprCode = exprCodes.dequeue()
|
||||
if (curExprCode.inputRow != null) {
|
||||
inputRows += curExprCode.inputRow
|
||||
}
|
||||
curExprCode.inputVars.foreach { inputVar =>
|
||||
if (!isEvaluated(inputVar.exprCode)) {
|
||||
exprCodes.enqueue(inputVar.exprCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
inputRows
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves previously evaluated columns referred by children and deferred expressions.
|
||||
* Returned tuple contains the list of expressions and the list of generated codes.
|
||||
*/
|
||||
def getInputVarsForChildren(
|
||||
ctx: CodegenContext,
|
||||
expr: Expression): Seq[ExprInputVar] = {
|
||||
expr.children.flatMap(getInputVars(ctx, _)).distinct
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a child expression, retrieves previously evaluated columns referred by it or
|
||||
* deferred expressions which are needed to evaluate it.
|
||||
*/
|
||||
def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = {
|
||||
if (ctx.currentVars == null) {
|
||||
return Seq.empty
|
||||
}
|
||||
|
||||
child.flatMap {
|
||||
// An evaluated variable.
|
||||
case b @ BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null &&
|
||||
isEvaluated(ctx.currentVars(ordinal)) =>
|
||||
Seq(ExprInputVar(ctx.currentVars(ordinal), b.dataType, b.nullable))
|
||||
|
||||
// An input variable which is not evaluated yet. Tracks down to find any evaluated variables
|
||||
// in the expression path.
|
||||
// E.g., if this expression is "d = c + 1" and "c" is not evaluated. We need to track to
|
||||
// "c = a + b" and see if "a" and "b" are evaluated. If they are, we need to return them so
|
||||
// to include them into parameters, if not, we track down further.
|
||||
case BoundReference(ordinal, _, _) if ctx.currentVars(ordinal) != null =>
|
||||
trackDownVar(ctx, ctx.currentVars(ordinal))
|
||||
|
||||
case _ => Seq.empty
|
||||
}.distinct
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracks down previously evaluated columns referred by the generated code snippet.
|
||||
*/
|
||||
def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = {
|
||||
val exprCodes = mutable.Queue[ExprCode](exprCode)
|
||||
val inputVars = mutable.ArrayBuffer.empty[ExprInputVar]
|
||||
|
||||
while (exprCodes.nonEmpty) {
|
||||
exprCodes.dequeue().inputVars.foreach { inputVar =>
|
||||
if (isEvaluated(inputVar.exprCode)) {
|
||||
inputVars += inputVar
|
||||
} else {
|
||||
exprCodes.enqueue(inputVar.exprCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
inputVars
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to calculate the size of an expression as function parameter.
|
||||
*/
|
||||
def calculateParamLength(ctx: CodegenContext, input: ExprInputVar): Int = {
|
||||
(if (input.nullable) 1 else 0) + ctx.javaType(input.dataType) match {
|
||||
case ctx.JAVA_LONG | ctx.JAVA_DOUBLE => 2
|
||||
case _ => 1
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* In Java, a method descriptor is valid only if it represents method parameters with a total
|
||||
* length of 255 or less. `this` contributes one unit and a parameter of type long or double
|
||||
* contributes two units.
|
||||
*/
|
||||
def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = {
|
||||
// Initial value is 1 for `this`.
|
||||
1 + inputs.map(calculateParamLength(ctx, _)).sum
|
||||
}
|
||||
|
||||
/**
|
||||
* Given the lists of input attributes and variables to this expression, returns the strings of
|
||||
* funtion parameters. The first is the variable names used to call the function, the second is
|
||||
* the parameters used to declare the function in generated code.
|
||||
*/
|
||||
def prepareFunctionParams(
|
||||
ctx: CodegenContext,
|
||||
inputVars: Seq[ExprInputVar]): Seq[(String, String)] = {
|
||||
inputVars.flatMap { inputVar =>
|
||||
val params = mutable.ArrayBuffer.empty[(String, String)]
|
||||
val ev = inputVar.exprCode
|
||||
|
||||
// Only include the expression value if it is not a literal.
|
||||
if (!isLiteral(ev)) {
|
||||
val argType = ctx.javaType(inputVar.dataType)
|
||||
params += ((ev.value, s"$argType ${ev.value}"))
|
||||
}
|
||||
|
||||
// If it is a nullable expression and `isNull` is not a literal.
|
||||
if (inputVar.nullable && ev.isNull != "true" && ev.isNull != "false") {
|
||||
params += ((ev.isNull, s"boolean ${ev.isNull}"))
|
||||
}
|
||||
|
||||
params
|
||||
}.distinct
|
||||
}
|
||||
|
||||
/**
|
||||
* Only applied to the `ExprCode` in `ctx.currentVars`.
|
||||
* Returns true if this value is a literal.
|
||||
*/
|
||||
def isLiteral(exprCode: ExprCode): Boolean = {
|
||||
assert(exprCode.value.nonEmpty, "ExprCode.value can't be empty string.")
|
||||
|
||||
if (exprCode.value == "true" || exprCode.value == "false" || exprCode.value == "null") {
|
||||
true
|
||||
} else {
|
||||
// The valid characters for the first character of a Java variable is [a-zA-Z_$].
|
||||
exprCode.value.head match {
|
||||
case v if v >= 'a' && v <= 'z' => false
|
||||
case v if v >= 'A' && v <= 'Z' => false
|
||||
case '_' | '$' => false
|
||||
case _ => true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Only applied to the `ExprCode` in `ctx.currentVars`.
|
||||
* The code is emptied after evaluation.
|
||||
*/
|
||||
def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code == ""
|
||||
}
|
|
@ -0,0 +1,220 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions.codegen
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.IntegerType
|
||||
|
||||
class ExpressionCodegenSuite extends SparkFunSuite {
|
||||
|
||||
test("Returns eliminated subexpressions for expression") {
|
||||
val ctx = new CodegenContext()
|
||||
val subExpr = Add(Literal(1), Literal(2))
|
||||
val exprs = Seq(Add(subExpr, Literal(3)), Add(subExpr, Literal(4)))
|
||||
|
||||
ctx.generateExpressions(exprs, doSubexpressionElimination = true)
|
||||
val subexpressions = ExpressionCodegen.getSubExprInChildren(ctx, exprs(0))
|
||||
assert(subexpressions.length == 1 && subexpressions(0) == subExpr)
|
||||
}
|
||||
|
||||
test("Gets parameters for subexpressions") {
|
||||
val ctx = new CodegenContext()
|
||||
val subExprs = Seq(
|
||||
Add(Literal(1), AttributeReference("a", IntegerType, nullable = false)()), // non-nullable
|
||||
Add(Literal(2), AttributeReference("b", IntegerType, nullable = true)())) // nullable
|
||||
|
||||
ctx.subExprEliminationExprs.put(subExprs(0), SubExprEliminationState("false", "value1"))
|
||||
ctx.subExprEliminationExprs.put(subExprs(1), SubExprEliminationState("isNull2", "value2"))
|
||||
|
||||
val subExprCodes = ExpressionCodegen.getSubExprCodes(ctx, subExprs)
|
||||
val subVars = subExprs.zip(subExprCodes).map { case (expr, exprCode) =>
|
||||
ExprInputVar(exprCode, expr.dataType, expr.nullable)
|
||||
}
|
||||
val params = ExpressionCodegen.prepareFunctionParams(ctx, subVars)
|
||||
assert(params.length == 3)
|
||||
assert(params(0) == Tuple2("value1", "int value1"))
|
||||
assert(params(1) == Tuple2("value2", "int value2"))
|
||||
assert(params(2) == Tuple2("isNull2", "boolean isNull2"))
|
||||
}
|
||||
|
||||
test("Returns input variables for expression: current variables") {
|
||||
val ctx = new CodegenContext()
|
||||
val currentVars = Seq(
|
||||
ExprCode("", isNull = "false", value = "value1"), // evaluated
|
||||
ExprCode("", isNull = "isNull2", value = "value2"), // evaluated
|
||||
ExprCode("fake code;", isNull = "isNull3", value = "value3")) // not evaluated
|
||||
ctx.currentVars = currentVars
|
||||
ctx.INPUT_ROW = null
|
||||
|
||||
val expr = If(Literal(false),
|
||||
Add(BoundReference(0, IntegerType, nullable = false),
|
||||
BoundReference(1, IntegerType, nullable = true)),
|
||||
BoundReference(2, IntegerType, nullable = true))
|
||||
|
||||
val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
|
||||
// Only two evaluated variables included.
|
||||
assert(inputVars.length == 2)
|
||||
assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == false)
|
||||
assert(inputVars(1).dataType == IntegerType && inputVars(1).nullable == true)
|
||||
assert(inputVars(0).exprCode == currentVars(0))
|
||||
assert(inputVars(1).exprCode == currentVars(1))
|
||||
|
||||
val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
|
||||
assert(params.length == 3)
|
||||
assert(params(0) == Tuple2("value1", "int value1"))
|
||||
assert(params(1) == Tuple2("value2", "int value2"))
|
||||
assert(params(2) == Tuple2("isNull2", "boolean isNull2"))
|
||||
}
|
||||
|
||||
test("Returns input variables for expression: deferred variables") {
|
||||
val ctx = new CodegenContext()
|
||||
|
||||
// The referred column is not evaluated yet. But it depends on an evaluated column from
|
||||
// other operator.
|
||||
val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1"))
|
||||
|
||||
// currentVars(0) depends on this evaluated column.
|
||||
currentVars(0).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull2", value = "value2"),
|
||||
dataType = IntegerType, nullable = true))
|
||||
ctx.currentVars = currentVars
|
||||
ctx.INPUT_ROW = null
|
||||
|
||||
val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false))
|
||||
val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
|
||||
assert(inputVars.length == 1)
|
||||
assert(inputVars(0).dataType == IntegerType && inputVars(0).nullable == true)
|
||||
|
||||
val params = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
|
||||
assert(params.length == 2)
|
||||
assert(params(0) == Tuple2("value2", "int value2"))
|
||||
assert(params(1) == Tuple2("isNull2", "boolean isNull2"))
|
||||
}
|
||||
|
||||
test("Returns input rows for expression") {
|
||||
val ctx = new CodegenContext()
|
||||
ctx.currentVars = null
|
||||
ctx.INPUT_ROW = "i"
|
||||
|
||||
val expr = Add(BoundReference(0, IntegerType, nullable = false),
|
||||
BoundReference(1, IntegerType, nullable = true))
|
||||
val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
|
||||
assert(inputRows.length == 1)
|
||||
assert(inputRows(0) == "i")
|
||||
}
|
||||
|
||||
test("Returns input rows for expression: deferred expression") {
|
||||
val ctx = new CodegenContext()
|
||||
|
||||
// The referred column is not evaluated yet. But it depends on an input row from
|
||||
// other operator.
|
||||
val currentVars = Seq(ExprCode("fake code;", isNull = "isNull1", value = "value1"))
|
||||
currentVars(0).inputRow = "inputadaptor_row1"
|
||||
ctx.currentVars = currentVars
|
||||
ctx.INPUT_ROW = null
|
||||
|
||||
val expr = Add(Literal(1), BoundReference(0, IntegerType, nullable = false))
|
||||
val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
|
||||
assert(inputRows.length == 1)
|
||||
assert(inputRows(0) == "inputadaptor_row1")
|
||||
}
|
||||
|
||||
test("Returns both input rows and variables for expression") {
|
||||
val ctx = new CodegenContext()
|
||||
// 5 input variables in currentVars:
|
||||
// 1 evaluated variable (value1).
|
||||
// 3 not evaluated variables.
|
||||
// value2 depends on an evaluated column from other operator.
|
||||
// value3 depends on an input row from other operator.
|
||||
// value4 depends on a not evaluated yet column from other operator.
|
||||
// 1 null indicating to use input row "i".
|
||||
val currentVars = Seq(
|
||||
ExprCode("", isNull = "false", value = "value1"),
|
||||
ExprCode("fake code;", isNull = "isNull2", value = "value2"),
|
||||
ExprCode("fake code;", isNull = "isNull3", value = "value3"),
|
||||
ExprCode("fake code;", isNull = "isNull4", value = "value4"),
|
||||
null)
|
||||
// value2 depends on this evaluated column.
|
||||
currentVars(1).inputVars = Seq(ExprInputVar(ExprCode("", isNull = "isNull5", value = "value5"),
|
||||
dataType = IntegerType, nullable = true))
|
||||
// value3 depends on an input row "inputadaptor_row1".
|
||||
currentVars(2).inputRow = "inputadaptor_row1"
|
||||
// value4 depends on another not evaluated yet column.
|
||||
currentVars(3).inputVars = Seq(ExprInputVar(ExprCode("fake code;",
|
||||
isNull = "isNull6", value = "value6"), dataType = IntegerType, nullable = true))
|
||||
ctx.currentVars = currentVars
|
||||
ctx.INPUT_ROW = "i"
|
||||
|
||||
// expr: if (false) { value1 + value2 } else { (value3 + value4) + i[5] }
|
||||
val expr = If(Literal(false),
|
||||
Add(BoundReference(0, IntegerType, nullable = false),
|
||||
BoundReference(1, IntegerType, nullable = true)),
|
||||
Add(Add(BoundReference(2, IntegerType, nullable = true),
|
||||
BoundReference(3, IntegerType, nullable = true)),
|
||||
BoundReference(4, IntegerType, nullable = true))) // this is based on input row "i".
|
||||
|
||||
// input rows: "i", "inputadaptor_row1".
|
||||
val inputRows = ExpressionCodegen.getInputRowsForChildren(ctx, expr)
|
||||
assert(inputRows.length == 2)
|
||||
assert(inputRows(0) == "inputadaptor_row1")
|
||||
assert(inputRows(1) == "i")
|
||||
|
||||
// input variables: value1 and value5
|
||||
val inputVars = ExpressionCodegen.getInputVarsForChildren(ctx, expr)
|
||||
assert(inputVars.length == 2)
|
||||
|
||||
// value1 has inlined isNull "false", so don't need to include it in the params.
|
||||
val inputVarParams = ExpressionCodegen.prepareFunctionParams(ctx, inputVars)
|
||||
assert(inputVarParams.length == 3)
|
||||
assert(inputVarParams(0) == Tuple2("value1", "int value1"))
|
||||
assert(inputVarParams(1) == Tuple2("value5", "int value5"))
|
||||
assert(inputVarParams(2) == Tuple2("isNull5", "boolean isNull5"))
|
||||
}
|
||||
|
||||
test("isLiteral: literals") {
|
||||
val literals = Seq(
|
||||
ExprCode("", "", "true"),
|
||||
ExprCode("", "", "false"),
|
||||
ExprCode("", "", "1"),
|
||||
ExprCode("", "", "-1"),
|
||||
ExprCode("", "", "1L"),
|
||||
ExprCode("", "", "-1L"),
|
||||
ExprCode("", "", "1.0f"),
|
||||
ExprCode("", "", "-1.0f"),
|
||||
ExprCode("", "", "0.1f"),
|
||||
ExprCode("", "", "-0.1f"),
|
||||
ExprCode("", "", """"string""""),
|
||||
ExprCode("", "", "(byte)-1"),
|
||||
ExprCode("", "", "(short)-1"),
|
||||
ExprCode("", "", "null"))
|
||||
|
||||
literals.foreach(l => assert(ExpressionCodegen.isLiteral(l) == true))
|
||||
}
|
||||
|
||||
test("isLiteral: non literals") {
|
||||
val variables = Seq(
|
||||
ExprCode("", "", "var1"),
|
||||
ExprCode("", "", "_var2"),
|
||||
ExprCode("", "", "$var3"),
|
||||
ExprCode("", "", "v1a2r3"),
|
||||
ExprCode("", "", "_1v2a3r"),
|
||||
ExprCode("", "", "$1v2a3r"))
|
||||
|
||||
variables.foreach(v => assert(ExpressionCodegen.isLiteral(v) == false))
|
||||
}
|
||||
}
|
|
@ -108,7 +108,10 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
|
|||
|}""".stripMargin)
|
||||
|
||||
ctx.currentVars = null
|
||||
// `rowIdx` isn't in `ctx.currentVars`. If the expressions are split later, we can't track it.
|
||||
// So making it as global variable.
|
||||
val rowidx = ctx.freshName("rowIdx")
|
||||
ctx.addMutableState(ctx.JAVA_INT, rowidx)
|
||||
val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
|
||||
genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
|
||||
}
|
||||
|
@ -128,7 +131,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
|
|||
| int $numRows = $batch.numRows();
|
||||
| int $localEnd = $numRows - $idx;
|
||||
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
|
||||
| int $rowidx = $idx + $localIdx;
|
||||
| $rowidx = $idx + $localIdx;
|
||||
| ${consume(ctx, columnsBatchInput).trim}
|
||||
| $shouldStop
|
||||
| }
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.sql.{QueryTest, Row, SaveMode}
|
||||
import org.apache.spark.sql.{Column, QueryTest, Row, SaveMode}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
|
||||
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
|
||||
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
|
||||
|
@ -236,4 +237,24 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-22551: Fix 64kb limit for deeply nested expressions under wholestage codegen") {
|
||||
import testImplicits._
|
||||
withTempPath { dir =>
|
||||
val path = dir.getCanonicalPath
|
||||
val df = Seq(("abc", 1)).toDF("key", "int")
|
||||
df.write.parquet(path)
|
||||
|
||||
var strExpr: Expression = col("key").expr
|
||||
for (_ <- 1 to 150) {
|
||||
strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8"))
|
||||
}
|
||||
val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
|
||||
|
||||
val df2 = spark.read.parquet(path).select(expressions.map(Column(_)): _*)
|
||||
val plan = df2.queryExecution.executedPlan
|
||||
assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
|
||||
df2.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue