[SPARK-9154] [SQL] codegen StringFormat
Jira: https://issues.apache.org/jira/browse/SPARK-9154 Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7546 from tarekauel/SPARK-9154 and squashes the following commits: a943d3e [Tarek Auel] [SPARK-9154] implicit input cast, added tests for null, support for null primitives 10b4de8 [Tarek Auel] [SPARK-9154][SQL] codegen removed fallback trait cd8322b [Tarek Auel] [SPARK-9154][SQL] codegen string format 086caba [Tarek Auel] [SPARK-9154][SQL] codegen string format
This commit is contained in:
parent
d45355ee22
commit
7f072c3d5e
|
@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
|
|||
/**
|
||||
* Returns the input formatted according to printf-style format strings
|
||||
*/
|
||||
case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
|
||||
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
|
||||
|
||||
require(children.nonEmpty, "printf() should take at least 1 argument")
|
||||
|
||||
|
@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
|
|||
private def format: Expression = children(0)
|
||||
private def args: Seq[Expression] = children.tail
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] =
|
||||
children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType)
|
||||
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
val pattern = format.eval(input)
|
||||
if (pattern == null) {
|
||||
|
@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Generates Java source that formats the first child (the pattern) with the remaining
 * children as arguments, using a java.util.Formatter and Locale.US.
 *
 * The generated result is null when the pattern is null. Null arguments do not make the
 * result null: they are handed to the formatter as Java null.
 *
 * @param ctx codegen context, used here for fresh variable names and Java type names
 * @param ev  holds the names of the generated null flag and result variable
 * @return the Java source snippet for this expression
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
// Generated code for the format pattern, i.e. the first child.
val pattern = children.head.gen(ctx)
|
||||

||||
// Pair every remaining child's data type with its generated code.
val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
|
||||
// Evaluation snippets of the arguments, one per line; emitted inside the non-null branch.
val argListCode = argListGen.map(_._2.code + "\n")
|
||||

||||
// Builds the trailing argument list for the format(...) call. Every entry is prefixed
// with "," so the result can be appended directly after the pattern argument (and is
// empty when there are no arguments).
val argListString = argListGen.foldLeft("")((s, v) => {
|
||||
val nullSafeString =
|
||||
// A boxed type that differs from the Java type is treated as a Java primitive.
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
|
||||
// Java primitives get boxed in order to allow null values.
|
||||
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
|
||||
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
|
||||
} else {
|
||||
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
|
||||
}
|
||||
s + "," + nullSafeString
|
||||
})
|
||||

||||
// Fresh local variable names and fully qualified class names used in the generated code.
val form = ctx.freshName("formatter")
|
||||
val formatter = classOf[java.util.Formatter].getName
|
||||
val sb = ctx.freshName("sb")
|
||||
val stringBuffer = classOf[StringBuffer].getName
|
||||
// The arguments are evaluated and formatted only when the pattern is non-null; the
// formatted text is written to the buffer and then converted to a UTF8String.
s"""
|
||||
${pattern.code}
|
||||
boolean ${ev.isNull} = ${pattern.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${argListCode.mkString}
|
||||
$stringBuffer $sb = new $stringBuffer();
|
||||
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
|
||||
$form.format(${pattern.primitive}.toString() $argListString);
|
||||
${ev.primitive} = UTF8String.fromString($sb.toString());
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
||||
override def prettyName: String = "printf"
|
||||
}
|
||||
|
||||
|
|
|
@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
}
|
||||
|
||||
test("FORMAT") {
|
||||
val f = 'f.string.at(0)
|
||||
val d1 = 'd.int.at(1)
|
||||
val s1 = 's.int.at(2)
|
||||
|
||||
val row1 = create_row("aa%d%s", 12, "cc")
|
||||
val row2 = create_row(null, 12, "cc")
|
||||
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
|
||||
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
|
||||
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
|
||||
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
|
||||
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
|
||||
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
|
||||
|
||||
checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
|
||||
checkEvaluation(StringFormat(f, d1, s1), null, row2)
|
||||
checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
|
||||
checkEvaluation(
|
||||
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
|
||||
checkEvaluation(
|
||||
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
|
||||
}
|
||||
|
||||
test("INSTR") {
|
||||
|
|
|
@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest {
|
|||
checkAnswer(
|
||||
df.selectExpr("printf(a, b, c)"),
|
||||
Row("aa123cc"))
|
||||
|
||||
val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
|
||||
|
||||
checkAnswer(
|
||||
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
|
||||
Row("aa123cc", "aa123cc"))
|
||||
|
||||
checkAnswer(
|
||||
df2.selectExpr("printf(a, b, c)"),
|
||||
Row("aa123cc"))
|
||||
}
|
||||
|
||||
test("string instr function") {
|
||||
|
|
Loading…
Reference in a new issue