[SPARK-22668][SQL] Ensure no global variables in arguments of method split by CodegenContext.splitExpressions()

## What changes were proposed in this pull request?

Passing global variables to the split method is dangerous, as any mutating to it is ignored and may lead to unexpected behavior.

To prevent this, one approach is to make sure no expression would output global variables: Localizing lifetime of mutable states in expressions.

Another approach is, when calling `ctx.splitExpression`, make sure we don't use children's output as parameter names.

Approach 1 is actually hard to do, as we need to check all expressions and operators that support whole-stage codegen. Approach 2 is easier as the callers of `ctx.splitExpressions` are not too many.

Besides, approach 2 is more flexible, as children's output may be other stuff that can't be parameter name: literal, inlined statement(a + 1), etc.

close https://github.com/apache/spark/pull/19865
close https://github.com/apache/spark/pull/19938

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes #20021 from cloud-fan/codegen.
This commit is contained in:
Wenchen Fan 2017-12-22 00:21:27 +08:00
parent 4c2efde931
commit 8a0ed5a5ee
5 changed files with 43 additions and 26 deletions

View file

@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull")
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
|if (!${eval.isNull} && ($tmpIsNull ||
|if (!${eval.isNull} && (${ev.isNull} ||
| ${ctx.genGreater(dataType, ev.value, eval.value)})) {
| $tmpIsNull = false;
| ${ev.isNull} = false;
| ${ev.value} = ${eval.value};
|}
""".stripMargin
@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends Expression {
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
|$tmpIsNull = true;
|${ev.isNull} = true;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|$codes
|final boolean ${ev.isNull} = $tmpIsNull;
""".stripMargin)
}
}
@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull")
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
|if (!${eval.isNull} && ($tmpIsNull ||
|if (!${eval.isNull} && (${ev.isNull} ||
| ${ctx.genGreater(dataType, eval.value, ev.value)})) {
| $tmpIsNull = false;
| ${ev.isNull} = false;
| ${ev.value} = ${eval.value};
|}
""".stripMargin
@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
|$tmpIsNull = true;
|${ev.isNull} = true;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|$codes
|final boolean ${ev.isNull} = $tmpIsNull;
""".stripMargin)
}
}

View file

@ -128,7 +128,7 @@ class CodegenContext {
* `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
* `Expression.genCode`.
*/
final var INPUT_ROW = "i"
var INPUT_ROW = "i"
/**
* Holding a list of generated columns as input of current operator, will be used by
@ -146,22 +146,30 @@ class CodegenContext {
* as a member variable
*
* They will be kept as member variables in generated classes like `SpecificProjection`.
*
* Exposed for tests only.
*/
val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
mutable.ArrayBuffer.empty[(String, String)]
/**
* The mapping between mutable state types and corrseponding compacted arrays.
* The keys are java type string. The values are [[MutableStateArrays]] which encapsulates
* the compacted arrays for the mutable states with the same java type.
*
* Exposed for tests only.
*/
val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
mutable.Map.empty[String, MutableStateArrays]
// An array holds the code that will initialize each state
val mutableStateInitCode: mutable.ArrayBuffer[String] =
// Exposed for tests only.
private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] =
mutable.ArrayBuffer.empty[String]
// Tracks the names of all the mutable states.
private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty
/**
* This class holds a set of names of mutableStateArrays that is used for compacting mutable
* states for a certain type, and holds the next available slot of the current compacted array.
@ -172,7 +180,11 @@ class CodegenContext {
private[this] var currentIndex = 0
private def createNewArray() = arrayNames.append(freshName("mutableStateArray"))
private def createNewArray() = {
val newArrayName = freshName("mutableStateArray")
mutableStateNames += newArrayName
arrayNames.append(newArrayName)
}
def getCurrentIndex: Int = currentIndex
@ -241,6 +253,7 @@ class CodegenContext {
val initCode = initFunc(varName)
inlinedMutableStates += ((javaType, varName))
mutableStateInitCode += initCode
mutableStateNames += varName
varName
} else {
val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays)
@ -930,6 +943,15 @@ class CodegenContext {
// inline execution if only one block
blocks.head
} else {
if (Utils.isTesting) {
// Passing global variables to the split method is dangerous, as any mutating to it is
// ignored and may lead to unexpected behavior.
arguments.foreach { case (_, name) =>
assert(!mutableStateNames.contains(name),
s"split function argument $name cannot be a global variable.")
}
}
val func = freshName(funcName)
val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
val functions = blocks.zipWithIndex.map { case (body, i) =>

View file

@ -190,7 +190,7 @@ case class CaseWhen(
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
// We won't go on anymore on the computation.
val resultState = ctx.freshName("caseWhenResultState")
val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult")
ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
// these blocks are meant to be inside a
// do {
@ -205,7 +205,7 @@ case class CaseWhen(
|if (!${cond.isNull} && ${cond.value}) {
| ${res.code}
| $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
| $tmpResult = ${res.value};
| ${ev.value} = ${res.value};
| continue;
|}
""".stripMargin
@ -216,7 +216,7 @@ case class CaseWhen(
s"""
|${res.code}
|$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
|$tmpResult = ${res.value};
|${ev.value} = ${res.value};
""".stripMargin
}
@ -264,13 +264,11 @@ case class CaseWhen(
ev.copy(code =
s"""
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|$tmpResult = ${ctx.defaultValue(dataType)};
|do {
| $codes
|} while (false);
|// TRUE if any condition is met and the result is null, or no any condition is met.
|final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
|final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
""".stripMargin)
}
}

View file

@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull")
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
// all the evals are meant to be in a do { ... } while (false); loop
val evals = children.map { e =>
@ -80,7 +80,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
s"""
|${eval.code}
|if (!${eval.isNull}) {
| $tmpIsNull = false;
| ${ev.isNull} = false;
| ${ev.value} = ${eval.value};
| continue;
|}
@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
foldFunctions = _.map { funcCall =>
s"""
|${ev.value} = $funcCall;
|if (!$tmpIsNull) {
|if (!${ev.isNull}) {
| continue;
|}
""".stripMargin
@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
|$tmpIsNull = true;
|${ev.isNull} = true;
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $codes
|} while (false);
|final boolean ${ev.isNull} = $tmpIsNull;
""".stripMargin)
}
}

View file

@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|${valueGen.code}
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
| $tmpResult = 0;
| $tmpResult = $NOT_MATCHED;
| $javaDataType $valueArg = ${valueGen.value};
| do {
| $codes