[SPARK-13147] [SQL] improve readability of generated code

1. try to avoid the suffix (unique id)
2. remove the comment if there is no code generated.
3. re-arrange the order of functions
4. trop the new line for inlined blocks.

Author: Davies Liu <davies@databricks.com>

Closes #11032 from davies/better_suffix.
This commit is contained in:
Davies Liu 2016-02-02 22:13:10 -08:00 committed by Davies Liu
parent 335f10edad
commit e86f8f63bf
7 changed files with 63 additions and 39 deletions

View file

@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] {
val value = ctx.freshName("value")
val ve = ExprCode("", isNull, value)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
if (ve.code != "") {
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
} else {
ve
}
}
}

View file

@ -156,7 +156,11 @@ class CodegenContext {
/** The variable name of the input row in generated code. */
final var INPUT_ROW = "i"
private val curId = new java.util.concurrent.atomic.AtomicInteger()
/**
* The map from a variable name to it's next ID.
*/
private val freshNameIds = new mutable.HashMap[String, Int]
freshNameIds += INPUT_ROW -> 1
/**
* A prefix used to generate fresh name.
@ -164,16 +168,21 @@ class CodegenContext {
var freshNamePrefix = ""
/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
* Returns a term name that is unique within this instance of a `CodegenContext`.
*/
def freshName(name: String): String = {
if (freshNamePrefix == "") {
s"$name${curId.getAndIncrement}"
def freshName(name: String): String = synchronized {
val fullName = if (freshNamePrefix == "") {
name
} else {
s"${freshNamePrefix}_$name${curId.getAndIncrement}"
s"${freshNamePrefix}_$name"
}
if (freshNameIds.contains(fullName)) {
val id = freshNameIds(fullName)
freshNameIds(fullName) = id + 1
s"$fullName$id"
} else {
freshNameIds += fullName -> 1
fullName
}
}

View file

@ -173,22 +173,26 @@ case class GetArrayStructFields(
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
val n = ctx.freshName("n")
val values = ctx.freshName("values")
val j = ctx.freshName("j")
val row = ctx.freshName("row")
s"""
final int n = $eval.numElements();
final Object[] values = new Object[n];
for (int j = 0; j < n; j++) {
if ($eval.isNullAt(j)) {
values[j] = null;
final int $n = $eval.numElements();
final Object[] $values = new Object[$n];
for (int $j = 0; $j < $n; $j++) {
if ($eval.isNullAt($j)) {
$values[$j] = null;
} else {
final InternalRow row = $eval.getStruct(j, $numFields);
if (row.isNullAt($ordinal)) {
values[j] = null;
final InternalRow $row = $eval.getStruct($j, $numFields);
if ($row.isNullAt($ordinal)) {
$values[$j] = null;
} else {
values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
$values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
}
}
}
${ev.value} = new $arrayClass(values);
${ev.value} = new $arrayClass($values);
"""
})
}
@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("index")
s"""
final int index = (int) $eval2;
if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) {
final int $index = (int) $eval2;
if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval1, dataType, "index")};
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
}
"""
})

View file

@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| }
""".stripMargin
}
@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
private Object[] references;
${ctx.declareMutableStates()}
${ctx.declareAddedFunctions()}
public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
this.references = references;
${ctx.initMutableStates()}
}
${ctx.declareAddedFunctions()}
protected void processNext() throws java.io.IOException {
$code
${code.trim}
}
}
"""

View file

@ -211,9 +211,9 @@ case class TungstenAggregate(
| $doAgg();
|
| // output the result
| $genResult
| ${genResult.trim}
|
| ${consume(ctx, resultVars)}
| ${consume(ctx, resultVars).trim}
| }
""".stripMargin
}
@ -242,9 +242,9 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
| ${aggVals.map(_.code).mkString("\n")}
| ${aggVals.map(_.code).mkString("\n").trim}
| // update aggregation buffer
| ${updates.mkString("")}
| ${updates.mkString("\n").trim}
""".stripMargin
}
@ -523,7 +523,7 @@ case class TungstenAggregate(
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
// generate grouping key
${keyCode.code}
${keyCode.code.trim}
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
@ -547,9 +547,9 @@ case class TungstenAggregate(
$incCounter
// evaluate aggregate function
${evals.map(_.code).mkString("\n")}
${evals.map(_.code).mkString("\n").trim}
// update aggregate buffer
${updates.mkString("\n")}
${updates.mkString("\n").trim}
"""
}

View file

@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
BindReferences.bindReference(condition, child.output))
ctx.currentVars = input
val eval = expr.gen(ctx)
val nullCheck = if (expr.nullable) {
s"!${eval.isNull} &&"
} else {
s""
}
s"""
| ${eval.code}
| if (!${eval.isNull} && ${eval.value}) {
| if ($nullCheck ${eval.value}) {
| ${consume(ctx, ctx.currentVars)}
| }
""".stripMargin

View file

@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
// These benchmark are skipped in normal build
ignore("benchmark") {
// testWholeStage(200 << 20)
// testStddev(20 << 20)
// testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
// testBytesToBytesMap(1024 * 1024 * 50)
}