[SPARK-14637][SQL] object expressions cleanup

## What changes were proposed in this pull request?

Simplify and clean up some object expressions:

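1. Simplify the logic that handles `propagateNull`.
2. Add a `propagateNull` parameter to `Invoke`.
3. Simplify the unboxing logic in `Invoke`.
4. Other minor cleanup (e.g. returning constant `isNull` values through `ev.copy` and registering the `AssertNotNull` error message with `ctx.addReferenceObj`).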

TODO: simplify `MapObjects`

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12399 from cloud-fan/object.
This commit is contained in:
Wenchen Fan 2016-05-02 10:21:14 -07:00 committed by Michael Armbrust
parent 214d1be4fd
commit 0513c3ac93

View file

@ -64,33 +64,29 @@ case class StaticInvoke(
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
if (propagateNull) {
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
val callFunc = s"$objectName.$functionName($argString)"
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
ev.copy(code = s"""
${argGen.map(_.code).mkString("\n")}
boolean ${ev.isNull} = !$argsNonNull;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if ($argsNonNull) {
${ev.value} = $objectName.$functionName($argString);
$objNullCheck
}
""")
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
ev.copy(code = s"""
${argGen.map(_.code).mkString("\n")}
$javaType ${ev.value} = $objectName.$functionName($argString);
final boolean ${ev.isNull} = ${ev.value} == null;
""")
s"boolean ${ev.isNull} = false;"
}
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
val code = s"""
${argGen.map(_.code).mkString("\n")}
$setIsNull
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
"""
ev.copy(code = code)
}
}
@ -111,7 +107,8 @@ case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments
@ -130,60 +127,53 @@ case class Invoke(
case _ => None
}
lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
s"((java.lang.Long)$s).longValue()"
case (FloatType, "java.lang.Object") => (s: String) =>
s"((java.lang.Float)$s).floatValue()"
case (ShortType, "java.lang.Object") => (s: String) =>
s"((java.lang.Short)$s).shortValue()"
case (ByteType, "java.lang.Object") => (s: String) =>
s"((java.lang.Byte)$s).byteValue()"
case (DoubleType, "java.lang.Object") => (s: String) =>
s"((java.lang.Double)$s).doubleValue()"
case (BooleanType, "java.lang.Object") => (s: String) =>
s"((java.lang.Boolean)$s).booleanValue()"
case _ => identity[String] _
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.genCode(ctx)
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"boolean ${ev.isNull} = ${ev.value} == null;"
val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
s"${obj.value}.$functionName($argString)"
} else {
ev.isNull = obj.isNull
""
s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
}
val value = unboxer(s"${obj.value}.$functionName($argString)")
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = ${obj.isNull};"
}
val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
} else {
s"""
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
try {
${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
} catch (Exception e) {
org.apache.spark.unsafe.Platform.throwException(e);
}
"""
}
ev.copy(code = s"""
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
val code = s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
$setIsNull
$evaluate
$objNullCheck
""")
$postNullCheck
"""
ev.copy(code = code)
}
override def toString: String = s"$targetObject.$functionName"
@ -246,11 +236,13 @@ case class NewInstance(
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
val setup =
s"""
${argGen.map(_.code).mkString("\n")}
${outer.map(_.code).getOrElse("")}
""".stripMargin
var isNull = ev.isNull
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
isNull = "false"
""
}
val constructorCall = outer.map { gen =>
s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
@ -258,27 +250,13 @@ case class NewInstance(
s"new $className($argString)"
}
if (propagateNull && argGen.nonEmpty) {
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
ev.copy(code = s"""
$setup
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if ($argsNonNull) {
${ev.value} = $constructorCall;
${ev.isNull} = false;
}
""")
} else {
ev.copy(code = s"""
$setup
final $javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = false;
""")
}
val code = s"""
${argGen.map(_.code).mkString("\n")}
${outer.map(_.code).getOrElse("")}
$setIsNull
final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
ev.copy(code = code, isNull = isNull)
}
override def toString: String = s"newInstance($cls)"
@ -306,13 +284,14 @@ case class UnwrapOption(
val javaType = ctx.javaType(dataType)
val inputObject = child.genCode(ctx)
ev.copy(code = s"""
val code = s"""
${inputObject.code}
boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty();
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} =
${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get();
""")
${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
"""
ev.copy(code = code)
}
}
@ -338,14 +317,14 @@ case class WrapOption(child: Expression, optType: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx)
ev.copy(code = s"""
val code = s"""
${inputObject.code}
boolean ${ev.isNull} = false;
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
""")
"""
ev.copy(code = code, isNull = "false")
}
}
@ -474,7 +453,7 @@ case class MapObjects private(
s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
}
ev.copy(code = s"""
val code = s"""
${genInputData.code}
boolean ${ev.isNull} = ${genInputData.value} == null;
@ -504,7 +483,8 @@ case class MapObjects private(
${ev.isNull} = false;
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
}
""")
"""
ev.copy(code = code)
}
}
@ -539,14 +519,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
}
"""
}
val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
val schemaField = ctx.addReferenceObj("schema", schema)
ev.copy(code = s"""
boolean ${ev.isNull} = false;
val code = s"""
$values = new Object[${children.size}];
$childrenCode
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
""")
"""
ev.copy(code = code, isNull = "false")
}
}
@ -579,14 +561,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
// Code to serialize.
val input = child.genCode(ctx)
ev.copy(code = s"""
val javaType = ctx.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"
val code = s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $serializer.serialize(${input.value}, null).array();
}
""")
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
"""
ev.copy(code = code, isNull = input.isNull)
}
override def dataType: DataType = BinaryType
@ -617,17 +599,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
serializer,
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
// Code to serialize.
// Code to deserialize.
val input = child.genCode(ctx)
ev.copy(code = s"""
val javaType = ctx.javaType(dataType)
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
val code = s"""
${input.code}
final boolean ${ev.isNull} = ${input.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.javaType(dataType)})
$serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
}
""")
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
"""
ev.copy(code = code, isNull = input.isNull)
}
override def dataType: DataType = ObjectType(tag.runtimeClass)
@ -658,15 +640,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
"""
}
ev.isNull = instanceGen.isNull
ev.value = instanceGen.value
ev.copy(code = s"""
val code = s"""
${instanceGen.code}
if (!${instanceGen.isNull}) {
${initialize.mkString("\n")}
}
""")
"""
ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
}
}
@ -696,13 +676,15 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)."
val idx = ctx.references.length
ctx.references += errMsg
ExprCode(code = s"""
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val code = s"""
${childGen.code}
if (${childGen.isNull}) {
throw new RuntimeException((String) references[$idx]);
}""", isNull = "false", value = childGen.value)
throw new RuntimeException(this.$errMsgField);
}
"""
ev.copy(code = code, isNull = "false", value = childGen.value)
}
}