[SPARK-22701][SQL] add ctx.splitExpressionsWithCurrentInputs

## What changes were proposed in this pull request?

This pattern appears many times in the codebase:
```
if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
  exprs.mkString("\n")
} else {
  ctx.splitExpressions(...)
}
```

This PR adds a `ctx.splitExpressionsWithCurrentInputs` for this pattern

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes #19895 from cloud-fan/splitExpression.
This commit is contained in:
Wenchen Fan 2017-12-05 10:15:15 -08:00 committed by gatorsmile
parent 03fdc92e42
commit ced6ccf0d6
12 changed files with 179 additions and 206 deletions

View file

@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression {
}
"""
}
val codes = ctx.splitExpressions(evalChildren.map(updateEval))
val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
"""
}
val codes = ctx.splitExpressions(evalChildren.map(updateEval))
val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};

View file

@ -781,29 +781,26 @@ class CodegenContext {
* beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
* instead, because classes have a constant pool limit of 65,536 named values.
*
* Note that we will extract the current inputs of this context and pass them to the generated
* functions. The input is `INPUT_ROW` for normal codegen path, and `currentVars` for whole
* stage codegen path. Whole stage codegen path is not supported yet.
*
* @param expressions the codes to evaluate expressions.
*/
def splitExpressions(expressions: Seq[String]): String = {
splitExpressions(expressions, funcName = "apply", extraArguments = Nil)
}
/**
* Similar to [[splitExpressions(expressions: Seq[String])]], but has customized function name
* and extra arguments.
* Note that different from `splitExpressions`, we will extract the current inputs of this
* context and pass them to the generated functions. The input is `INPUT_ROW` for normal codegen
* path, and `currentVars` for whole stage codegen path. Whole stage codegen path is not
* supported yet.
*
* @param expressions the codes to evaluate expressions.
* @param funcName the split function name base.
* @param extraArguments the list of (type, name) of the arguments of the split function
* except for ctx.INPUT_ROW
*/
def splitExpressions(
* @param extraArguments the list of (type, name) of the arguments of the split function,
* except for the current inputs like `ctx.INPUT_ROW`.
* @param returnType the return type of the split function.
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
* @param foldFunctions folds the split function calls.
*/
def splitExpressionsWithCurrentInputs(
expressions: Seq[String],
funcName: String,
extraArguments: Seq[(String, String)]): String = {
funcName: String = "apply",
extraArguments: Seq[(String, String)] = Nil,
returnType: String = "void",
makeSplitFunction: String => String = identity,
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
// TODO: support whole stage codegen
if (INPUT_ROW == null || currentVars != null) {
expressions.mkString("\n")
@ -811,13 +808,18 @@ class CodegenContext {
splitExpressions(
expressions,
funcName,
arguments = ("InternalRow", INPUT_ROW) +: extraArguments)
("InternalRow", INPUT_ROW) +: extraArguments,
returnType,
makeSplitFunction,
foldFunctions)
}
}
/**
* Splits the generated code of expressions into multiple functions, because function has
* 64kb code size limit in JVM
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow
* beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
* instead, because classes have a constant pool limit of 65,536 named values.
*
* @param expressions the codes to evaluate expressions.
* @param funcName the split function name base.

View file

@ -91,8 +91,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
val allProjections = ctx.splitExpressions(projectionCodes)
val allUpdates = ctx.splitExpressions(updates)
val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes)
val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
val codeBody = s"""
public java.lang.Object generate(Object[] references) {

View file

@ -159,7 +159,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""
}
val allExpressions = ctx.splitExpressions(expressionCodes)
val allExpressions = ctx.splitExpressionsWithCurrentInputs(expressionCodes)
val codeBody = s"""
public java.lang.Object generate(Object[] references) {

View file

@ -108,7 +108,7 @@ private [sql] object GenArrayData {
}
"""
}
val assignmentString = ctx.splitExpressions(
val assignmentString = ctx.splitExpressionsWithCurrentInputs(
expressions = assignments,
funcName = "apply",
extraArguments = ("Object[]", arrayDataName) :: Nil)
@ -139,7 +139,7 @@ private [sql] object GenArrayData {
}
"""
}
val assignmentString = ctx.splitExpressions(
val assignmentString = ctx.splitExpressionsWithCurrentInputs(
expressions = assignments,
funcName = "apply",
extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil)
@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
ctx.addMutableState("Object[]", values, s"$values = null;")
val valuesCode = ctx.splitExpressions(
val valuesCode = ctx.splitExpressionsWithCurrentInputs(
valExprs.zipWithIndex.map { case (e, i) =>
val eval = e.genCode(ctx)
s"""

View file

@ -219,57 +219,51 @@ case class CaseWhen(
val allConditions = cases ++ elseCode
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
allConditions.mkString("\n")
} else {
// This generates code like:
// conditionMet = caseWhen_1(i);
// if(conditionMet) {
// continue;
// }
// conditionMet = caseWhen_2(i);
// if(conditionMet) {
// continue;
// }
// ...
// and the declared methods are:
// private boolean caseWhen_1234() {
// boolean conditionMet = false;
// do {
// // here the evaluation of the conditions
// } while (false);
// return conditionMet;
// }
ctx.splitExpressions(allConditions, "caseWhen",
("InternalRow", ctx.INPUT_ROW) :: Nil,
returnType = ctx.JAVA_BOOLEAN,
makeSplitFunction = {
func =>
s"""
${ctx.JAVA_BOOLEAN} $conditionMet = false;
do {
$func
} while (false);
return $conditionMet;
"""
},
foldFunctions = { funcCalls =>
funcCalls.map { funcCall =>
s"""
$conditionMet = $funcCall;
if ($conditionMet) {
continue;
}"""
}.mkString
})
}
// This generates code like:
// conditionMet = caseWhen_1(i);
// if(conditionMet) {
// continue;
// }
// conditionMet = caseWhen_2(i);
// if(conditionMet) {
// continue;
// }
// ...
// and the declared methods are:
// private boolean caseWhen_1234() {
// boolean conditionMet = false;
// do {
// // here the evaluation of the conditions
// } while (false);
// return conditionMet;
// }
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = allConditions,
funcName = "caseWhen",
returnType = ctx.JAVA_BOOLEAN,
makeSplitFunction = func =>
s"""
|${ctx.JAVA_BOOLEAN} $conditionMet = false;
|do {
| $func
|} while (false);
|return $conditionMet;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$conditionMet = $funcCall;
|if ($conditionMet) {
| continue;
|}
""".stripMargin
}.mkString)
ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
${ctx.JAVA_BOOLEAN} $conditionMet = false;
do {
$code
$codes
} while (false);""")
}
}

View file

@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];")
val values = children.tail
val dataTypes = values.take(numFields).map(_.dataType)
val code = ctx.splitExpressions(Seq.tabulate(numRows) { row =>
val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row =>
val fields = Seq.tabulate(numFields) { col =>
val index = row * numFields + col
if (index < values.length) values(index) else Literal(null, dataTypes(col))

View file

@ -279,21 +279,17 @@ abstract class HashExpression[E] extends Expression {
}
val hashResultType = ctx.javaType(dataType)
val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
childrenHash.mkString("\n")
} else {
ctx.splitExpressions(
expressions = childrenHash,
funcName = "computeHash",
arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> ev.value),
returnType = hashResultType,
makeSplitFunction = body =>
s"""
|$body
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
extraArguments = Seq(hashResultType -> ev.value),
returnType = hashResultType,
makeSplitFunction = body =>
s"""
|$body
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
@ -652,22 +648,19 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
""".stripMargin
}
val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
childrenHash.mkString("\n")
} else {
ctx.splitExpressions(
expressions = childrenHash,
funcName = "computeHash",
arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> ev.value),
returnType = ctx.JAVA_INT,
makeSplitFunction = body =>
s"""
|${ctx.JAVA_INT} $childHash = 0;
|$body
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
extraArguments = Seq(ctx.JAVA_INT -> ev.value),
returnType = ctx.JAVA_INT,
makeSplitFunction = body =>
s"""
|${ctx.JAVA_INT} $childHash = 0;
|$body
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""

View file

@ -87,37 +87,32 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|}
""".stripMargin
}
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
evals.mkString("\n")
} else {
ctx.splitExpressions(evals, "coalesce",
("InternalRow", ctx.INPUT_ROW) :: Nil,
makeSplitFunction = {
func =>
s"""
|do {
| $func
|} while (false);
""".stripMargin
},
foldFunctions = { funcCalls =>
funcCalls.map { funcCall =>
s"""
|$funcCall;
|if (!${ev.isNull}) {
| continue;
|}
""".stripMargin
}.mkString
})
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "coalesce",
makeSplitFunction = func =>
s"""
|do {
| $func
|} while (false);
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$funcCall;
|if (!${ev.isNull}) {
| continue;
|}
""".stripMargin
}.mkString)
ev.copy(code =
s"""
|${ev.isNull} = true;
|${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $code
| $codes
|} while (false);
""".stripMargin)
}
@ -415,39 +410,32 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}
}
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
evals.mkString("\n")
} else {
ctx.splitExpressions(
expressions = evals,
funcName = "atLeastNNonNulls",
arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil,
returnType = ctx.JAVA_INT,
makeSplitFunction = { body =>
s"""
|do {
| $body
|} while (false);
|return $nonnull;
""".stripMargin
},
foldFunctions = { funcCalls =>
funcCalls.map(funcCall =>
s"""
|$nonnull = $funcCall;
|if ($nonnull >= $n) {
| continue;
|}
""".stripMargin).mkString("\n")
}
)
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "atLeastNNonNulls",
extraArguments = (ctx.JAVA_INT, nonnull) :: Nil,
returnType = ctx.JAVA_INT,
makeSplitFunction = body =>
s"""
|do {
| $body
|} while (false);
|return $nonnull;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$nonnull = $funcCall;
|if ($nonnull >= $n) {
| continue;
|}
""".stripMargin
}.mkString)
ev.copy(code =
s"""
|${ctx.JAVA_INT} $nonnull = 0;
|do {
| $code
| $codes
|} while (false);
|${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
""".stripMargin, isNull = "false")

View file

@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
"""
}
}
val argCode = ctx.splitExpressions(argCodes)
val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes)
(argCode, argValues.mkString(", "), resultIsNull)
}
@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
"""
}
val childrenCode = ctx.splitExpressions(childrenCodes)
val childrenCode = ctx.splitExpressionsWithCurrentInputs(childrenCodes)
val schemaField = ctx.addReferenceObj("schema", schema)
val code = s"""
@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
${javaBeanInstance}.$setterMethod(${fieldGen.value});
"""
}
val initializeCode = ctx.splitExpressions(initialize.toSeq)
val initializeCode = ctx.splitExpressionsWithCurrentInputs(initialize.toSeq)
val code = s"""
${instanceGen.code}

View file

@ -253,31 +253,26 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
| continue;
|}
""".stripMargin)
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
listCode.mkString("\n")
} else {
ctx.splitExpressions(
expressions = listCode,
funcName = "valueIn",
arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil,
makeSplitFunction = { body =>
s"""
|do {
| $body
|} while (false);
""".stripMargin
},
foldFunctions = { funcCalls =>
funcCalls.map(funcCall =>
s"""
|$funcCall;
|if (${ev.value}) {
| continue;
|}
""".stripMargin).mkString("\n")
}
)
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = listCode,
funcName = "valueIn",
extraArguments = (javaDataType, valueArg) :: Nil,
makeSplitFunction = body =>
s"""
|do {
| $body
|} while (false);
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$funcCall;
|if (${ev.value}) {
| continue;
|}
""".stripMargin
}.mkString("\n"))
ev.copy(code =
s"""
|${valueGen.code}
@ -286,7 +281,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|if (!${ev.isNull}) {
| $javaDataType $valueArg = ${valueGen.value};
| do {
| $code
| $codes
| } while (false);
|}
""".stripMargin)

View file

@ -73,7 +73,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}
val codes = ctx.splitExpressions(
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = ("UTF8String[]", args) :: Nil)
@ -152,7 +152,7 @@ case class ConcatWs(children: Seq[Expression])
""
}
}
val codes = ctx.splitExpressions(
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcatWs",
extraArguments = ("UTF8String[]", args) :: Nil)
@ -200,31 +200,32 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip
val codes = ctx.splitExpressions(evals.map(_.code))
val varargCounts = ctx.splitExpressions(
val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
val varargCounts = ctx.splitExpressionsWithCurrentInputs(
expressions = varargCount,
funcName = "varargCountsConcatWs",
arguments = ("InternalRow", ctx.INPUT_ROW) :: Nil,
returnType = "int",
makeSplitFunction = body =>
s"""
int $varargNum = 0;
$body
return $varargNum;
""",
foldFunctions = _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";"))
val varargBuilds = ctx.splitExpressions(
|int $varargNum = 0;
|$body
|return $varargNum;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$varargNum += $funcCall;").mkString("\n"))
val varargBuilds = ctx.splitExpressionsWithCurrentInputs(
expressions = varargBuild,
funcName = "varargBuildsConcatWs",
arguments =
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
returnType = "int",
makeSplitFunction = body =>
s"""
$body
return $idxInVararg;
""",
foldFunctions = _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";"))
|$body
|return $idxInVararg;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n"))
ev.copy(
s"""
$codes
@ -1380,7 +1381,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
$argList[$index] = $value;
"""
}
val argListCodes = ctx.splitExpressions(
val argListCodes = ctx.splitExpressionsWithCurrentInputs(
expressions = argListCode,
funcName = "valueFormatString",
extraArguments = ("Object[]", argList) :: Nil)