[SPARK-13147] [SQL] improve readability of generated code
1. try to avoid the suffix (unique id) 2. remove the comment if there is no code generated. 3. re-arrange the order of functions 4. drop the new line for inlined blocks. Author: Davies Liu <davies@databricks.com> Closes #11032 from davies/better_suffix.
This commit is contained in:
parent
335f10edad
commit
e86f8f63bf
|
@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
val value = ctx.freshName("value")
|
||||
val ve = ExprCode("", isNull, value)
|
||||
ve.code = genCode(ctx, ve)
|
||||
// Add `this` in the comment.
|
||||
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
|
||||
if (ve.code != "") {
|
||||
// Add `this` in the comment.
|
||||
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
|
||||
} else {
|
||||
ve
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -156,7 +156,11 @@ class CodegenContext {
|
|||
/** The variable name of the input row in generated code. */
|
||||
final var INPUT_ROW = "i"
|
||||
|
||||
private val curId = new java.util.concurrent.atomic.AtomicInteger()
|
||||
/**
|
||||
* The map from a variable name to its next ID.
|
||||
*/
|
||||
private val freshNameIds = new mutable.HashMap[String, Int]
|
||||
freshNameIds += INPUT_ROW -> 1
|
||||
|
||||
/**
|
||||
* A prefix used to generate fresh name.
|
||||
|
@ -164,16 +168,21 @@ class CodegenContext {
|
|||
var freshNamePrefix = ""
|
||||
|
||||
/**
|
||||
* Returns a term name that is unique within this instance of a `CodeGenerator`.
|
||||
*
|
||||
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
|
||||
* function.)
|
||||
* Returns a term name that is unique within this instance of a `CodegenContext`.
|
||||
*/
|
||||
def freshName(name: String): String = {
|
||||
if (freshNamePrefix == "") {
|
||||
s"$name${curId.getAndIncrement}"
|
||||
def freshName(name: String): String = synchronized {
|
||||
val fullName = if (freshNamePrefix == "") {
|
||||
name
|
||||
} else {
|
||||
s"${freshNamePrefix}_$name${curId.getAndIncrement}"
|
||||
s"${freshNamePrefix}_$name"
|
||||
}
|
||||
if (freshNameIds.contains(fullName)) {
|
||||
val id = freshNameIds(fullName)
|
||||
freshNameIds(fullName) = id + 1
|
||||
s"$fullName$id"
|
||||
} else {
|
||||
freshNameIds += fullName -> 1
|
||||
fullName
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -173,22 +173,26 @@ case class GetArrayStructFields(
|
|||
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
|
||||
val arrayClass = classOf[GenericArrayData].getName
|
||||
nullSafeCodeGen(ctx, ev, eval => {
|
||||
val n = ctx.freshName("n")
|
||||
val values = ctx.freshName("values")
|
||||
val j = ctx.freshName("j")
|
||||
val row = ctx.freshName("row")
|
||||
s"""
|
||||
final int n = $eval.numElements();
|
||||
final Object[] values = new Object[n];
|
||||
for (int j = 0; j < n; j++) {
|
||||
if ($eval.isNullAt(j)) {
|
||||
values[j] = null;
|
||||
final int $n = $eval.numElements();
|
||||
final Object[] $values = new Object[$n];
|
||||
for (int $j = 0; $j < $n; $j++) {
|
||||
if ($eval.isNullAt($j)) {
|
||||
$values[$j] = null;
|
||||
} else {
|
||||
final InternalRow row = $eval.getStruct(j, $numFields);
|
||||
if (row.isNullAt($ordinal)) {
|
||||
values[j] = null;
|
||||
final InternalRow $row = $eval.getStruct($j, $numFields);
|
||||
if ($row.isNullAt($ordinal)) {
|
||||
$values[$j] = null;
|
||||
} else {
|
||||
values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
|
||||
$values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
|
||||
}
|
||||
}
|
||||
}
|
||||
${ev.value} = new $arrayClass(values);
|
||||
${ev.value} = new $arrayClass($values);
|
||||
"""
|
||||
})
|
||||
}
|
||||
|
@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
|
|||
|
||||
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
|
||||
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
|
||||
val index = ctx.freshName("index")
|
||||
s"""
|
||||
final int index = (int) $eval2;
|
||||
if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) {
|
||||
final int $index = (int) $eval2;
|
||||
if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.value} = ${ctx.getValue(eval1, dataType, "index")};
|
||||
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
|
||||
}
|
||||
"""
|
||||
})
|
||||
|
|
|
@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
|
|||
s"""
|
||||
| while (input.hasNext()) {
|
||||
| InternalRow $row = (InternalRow) input.next();
|
||||
| ${columns.map(_.code).mkString("\n")}
|
||||
| ${consume(ctx, columns)}
|
||||
| ${columns.map(_.code).mkString("\n").trim}
|
||||
| ${consume(ctx, columns).trim}
|
||||
| }
|
||||
""".stripMargin
|
||||
}
|
||||
|
@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
|||
|
||||
private Object[] references;
|
||||
${ctx.declareMutableStates()}
|
||||
${ctx.declareAddedFunctions()}
|
||||
|
||||
public GeneratedIterator(Object[] references) {
|
||||
this.references = references;
|
||||
${ctx.initMutableStates()}
|
||||
this.references = references;
|
||||
${ctx.initMutableStates()}
|
||||
}
|
||||
|
||||
${ctx.declareAddedFunctions()}
|
||||
|
||||
protected void processNext() throws java.io.IOException {
|
||||
$code
|
||||
${code.trim}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
|
@ -211,9 +211,9 @@ case class TungstenAggregate(
|
|||
| $doAgg();
|
||||
|
|
||||
| // output the result
|
||||
| $genResult
|
||||
| ${genResult.trim}
|
||||
|
|
||||
| ${consume(ctx, resultVars)}
|
||||
| ${consume(ctx, resultVars).trim}
|
||||
| }
|
||||
""".stripMargin
|
||||
}
|
||||
|
@ -242,9 +242,9 @@ case class TungstenAggregate(
|
|||
}
|
||||
s"""
|
||||
| // do aggregate
|
||||
| ${aggVals.map(_.code).mkString("\n")}
|
||||
| ${aggVals.map(_.code).mkString("\n").trim}
|
||||
| // update aggregation buffer
|
||||
| ${updates.mkString("")}
|
||||
| ${updates.mkString("\n").trim}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
|
@ -523,7 +523,7 @@ case class TungstenAggregate(
|
|||
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
|
||||
s"""
|
||||
// generate grouping key
|
||||
${keyCode.code}
|
||||
${keyCode.code.trim}
|
||||
UnsafeRow $buffer = null;
|
||||
if ($checkFallback) {
|
||||
// try to get the buffer from hash map
|
||||
|
@ -547,9 +547,9 @@ case class TungstenAggregate(
|
|||
$incCounter
|
||||
|
||||
// evaluate aggregate function
|
||||
${evals.map(_.code).mkString("\n")}
|
||||
${evals.map(_.code).mkString("\n").trim}
|
||||
// update aggregate buffer
|
||||
${updates.mkString("\n")}
|
||||
${updates.mkString("\n").trim}
|
||||
"""
|
||||
}
|
||||
|
||||
|
|
|
@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
|
|||
BindReferences.bindReference(condition, child.output))
|
||||
ctx.currentVars = input
|
||||
val eval = expr.gen(ctx)
|
||||
val nullCheck = if (expr.nullable) {
|
||||
s"!${eval.isNull} &&"
|
||||
} else {
|
||||
s""
|
||||
}
|
||||
s"""
|
||||
| ${eval.code}
|
||||
| if (!${eval.isNull} && ${eval.value}) {
|
||||
| if ($nullCheck ${eval.value}) {
|
||||
| ${consume(ctx, ctx.currentVars)}
|
||||
| }
|
||||
""".stripMargin
|
||||
|
|
|
@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
|
|||
// These benchmarks are skipped in the normal build
|
||||
ignore("benchmark") {
|
||||
// testWholeStage(200 << 20)
|
||||
// testStddev(20 << 20)
|
||||
// testStatFunctions(20 << 20)
|
||||
// testAggregateWithKey(20 << 20)
|
||||
// testBytesToBytesMap(1024 * 1024 * 50)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue