From 8a0ed5a5ee64a6e854c516f80df5a9729435479b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 22 Dec 2017 00:21:27 +0800 Subject: [PATCH] [SPARK-22668][SQL] Ensure no global variables in arguments of method split by CodegenContext.splitExpressions() ## What changes were proposed in this pull request? Passing global variables to the split method is dangerous, as any mutating to it is ignored and may lead to unexpected behavior. To prevent this, one approach is to make sure no expression would output global variables: Localizing lifetime of mutable states in expressions. Another approach is, when calling `ctx.splitExpression`, make sure we don't use children's output as parameter names. Approach 1 is actually hard to do, as we need to check all expressions and operators that support whole-stage codegen. Approach 2 is easier as the callers of `ctx.splitExpressions` are not too many. Besides, approach 2 is more flexible, as children's output may be other stuff that can't be parameter name: literal, inlined statement(a + 1), etc. close https://github.com/apache/spark/pull/19865 close https://github.com/apache/spark/pull/19938 ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20021 from cloud-fan/codegen. --- .../sql/catalyst/expressions/arithmetic.scala | 18 +++++------ .../expressions/codegen/CodeGenerator.scala | 32 ++++++++++++++++--- .../expressions/conditionalExpressions.scala | 8 ++--- .../expressions/nullExpressions.scala | 9 +++--- .../sql/catalyst/expressions/predicates.scala | 2 +- 5 files changed, 43 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d3a8cb5804..8bb14598a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && ($tmpIsNull || + |if (!${eval.isNull} && (${ev.isNull} || | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; |} """.stripMargin @@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |$codes - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } @@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && ($tmpIsNull || + |if (!${eval.isNull} && (${ev.isNull} || | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; |} """.stripMargin @@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |$codes - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 41a920ba3d..9adf632ddc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -128,7 +128,7 @@ class CodegenContext { * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling * `Expression.genCode`. */ - final var INPUT_ROW = "i" + var INPUT_ROW = "i" /** * Holding a list of generated columns as input of current operator, will be used by @@ -146,22 +146,30 @@ class CodegenContext { * as a member variable * * They will be kept as member variables in generated classes like `SpecificProjection`. + * + * Exposed for tests only. */ - val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = + private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = mutable.ArrayBuffer.empty[(String, String)] /** * The mapping between mutable state types and corrseponding compacted arrays. * The keys are java type string. The values are [[MutableStateArrays]] which encapsulates * the compacted arrays for the mutable states with the same java type. + * + * Exposed for tests only. */ - val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = + private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = mutable.Map.empty[String, MutableStateArrays] // An array holds the code that will initialize each state - val mutableStateInitCode: mutable.ArrayBuffer[String] = + // Exposed for tests only. + private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] + // Tracks the names of all the mutable states. + private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty + /** * This class holds a set of names of mutableStateArrays that is used for compacting mutable * states for a certain type, and holds the next available slot of the current compacted array. @@ -172,7 +180,11 @@ class CodegenContext { private[this] var currentIndex = 0 - private def createNewArray() = arrayNames.append(freshName("mutableStateArray")) + private def createNewArray() = { + val newArrayName = freshName("mutableStateArray") + mutableStateNames += newArrayName + arrayNames.append(newArrayName) + } def getCurrentIndex: Int = currentIndex @@ -241,6 +253,7 @@ class CodegenContext { val initCode = initFunc(varName) inlinedMutableStates += ((javaType, varName)) mutableStateInitCode += initCode + mutableStateNames += varName varName } else { val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays) @@ -930,6 +943,15 @@ class CodegenContext { // inline execution if only one block blocks.head } else { + if (Utils.isTesting) { + // Passing global variables to the split method is dangerous, as any mutating to it is + // ignored and may lead to unexpected behavior. + arguments.foreach { case (_, name) => + assert(!mutableStateNames.contains(name), + s"split function argument $name cannot be a global variable.") + } + } + val func = freshName(funcName) val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") val functions = blocks.zipWithIndex.map { case (body, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1a9b68222a..142dfb02be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -190,7 +190,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult") + ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -205,7 +205,7 @@ case class CaseWhen( |if (!${cond.isNull} && ${cond.value}) { | ${res.code} | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); - | $tmpResult = ${res.value}; + | ${ev.value} = ${res.value}; | continue; |} """.stripMargin @@ -216,7 +216,7 @@ case class CaseWhen( s""" |${res.code} |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); - |$tmpResult = ${res.value}; + |${ev.value} = ${res.value}; """.stripMargin } @@ -264,13 +264,11 @@ case class CaseWhen( ev.copy(code = s""" |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; - |$tmpResult = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); |// TRUE if any condition is met and the result is null, or no any condition is met. |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); - |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b4f895fffd..470d5da041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -80,7 +80,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; | continue; |} @@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { foldFunctions = _.map { funcCall => s""" |${ev.value} = $funcCall; - |if (!$tmpIsNull) { + |if (!${ev.isNull}) { | continue; |} """.stripMargin @@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac9f56f78e..f4ee3d10f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { - | $tmpResult = 0; + | $tmpResult = $NOT_MATCHED; | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes