[SPARK-14637][SQL] object expressions cleanup

## What changes were proposed in this pull request?

Simplify and clean up some object expressions:

1. simplify the logic to handle `propagateNull`
2. add `propagateNull` parameter to `Invoke`
3. simplify the unbox logic in `Invoke`
4. other minor cleanup

TODO: simplify `MapObjects`

## How was this patch tested?

Covered by existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12399 from cloud-fan/object.
This commit is contained in:
Wenchen Fan 2016-05-02 10:21:14 -07:00 committed by Michael Armbrust
parent 214d1be4fd
commit 0513c3ac93

View file

@ -64,33 +64,29 @@ case class StaticInvoke(
val argGen = arguments.map(_.genCode(ctx)) val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ") val argString = argGen.map(_.value).mkString(", ")
if (propagateNull) { val callFunc = s"$objectName.$functionName($argString)"
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" val setIsNull = if (propagateNull && arguments.nonEmpty) {
ev.copy(code = s""" s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
${argGen.map(_.code).mkString("\n")}
boolean ${ev.isNull} = !$argsNonNull;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if ($argsNonNull) {
${ev.value} = $objectName.$functionName($argString);
$objNullCheck
}
""")
} else { } else {
ev.copy(code = s""" s"boolean ${ev.isNull} = false;"
${argGen.map(_.code).mkString("\n")}
$javaType ${ev.value} = $objectName.$functionName($argString);
final boolean ${ev.isNull} = ${ev.value} == null;
""")
} }
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
val code = s"""
${argGen.map(_.code).mkString("\n")}
$setIsNull
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
"""
ev.copy(code = code)
} }
} }
@ -111,7 +107,8 @@ case class Invoke(
targetObject: Expression, targetObject: Expression,
functionName: String, functionName: String,
dataType: DataType, dataType: DataType,
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
override def nullable: Boolean = true override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments override def children: Seq[Expression] = targetObject +: arguments
@ -130,60 +127,53 @@ case class Invoke(
case _ => None case _ => None
} }
lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
s"((java.lang.Long)$s).longValue()"
case (FloatType, "java.lang.Object") => (s: String) =>
s"((java.lang.Float)$s).floatValue()"
case (ShortType, "java.lang.Object") => (s: String) =>
s"((java.lang.Short)$s).shortValue()"
case (ByteType, "java.lang.Object") => (s: String) =>
s"((java.lang.Byte)$s).byteValue()"
case (DoubleType, "java.lang.Object") => (s: String) =>
s"((java.lang.Double)$s).doubleValue()"
case (BooleanType, "java.lang.Object") => (s: String) =>
s"((java.lang.Boolean)$s).booleanValue()"
case _ => identity[String] _
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType) val javaType = ctx.javaType(dataType)
val obj = targetObject.genCode(ctx) val obj = targetObject.genCode(ctx)
val argGen = arguments.map(_.genCode(ctx)) val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ") val argString = argGen.map(_.value).mkString(", ")
// If the function can return null, we do an extra check to make sure our null bit is still set val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
// correctly. s"${obj.value}.$functionName($argString)"
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"boolean ${ev.isNull} = ${ev.value} == null;"
} else { } else {
ev.isNull = obj.isNull s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
""
} }
val value = unboxer(s"${obj.value}.$functionName($argString)") val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = ${obj.isNull};"
}
val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
} else { } else {
s""" s"""
$javaType ${ev.value} = ${ctx.defaultValue(javaType)}; $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
try { try {
${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value; ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
} catch (Exception e) { } catch (Exception e) {
org.apache.spark.unsafe.Platform.throwException(e); org.apache.spark.unsafe.Platform.throwException(e);
} }
""" """
} }
ev.copy(code = s""" // If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
val code = s"""
${obj.code} ${obj.code}
${argGen.map(_.code).mkString("\n")} ${argGen.map(_.code).mkString("\n")}
$setIsNull
$evaluate $evaluate
$objNullCheck $postNullCheck
""") """
ev.copy(code = code)
} }
override def toString: String = s"$targetObject.$functionName" override def toString: String = s"$targetObject.$functionName"
@ -246,11 +236,13 @@ case class NewInstance(
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
val setup = var isNull = ev.isNull
s""" val setIsNull = if (propagateNull && arguments.nonEmpty) {
${argGen.map(_.code).mkString("\n")} s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
${outer.map(_.code).getOrElse("")} } else {
""".stripMargin isNull = "false"
""
}
val constructorCall = outer.map { gen => val constructorCall = outer.map { gen =>
s"""${gen.value}.new ${cls.getSimpleName}($argString)""" s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
@ -258,27 +250,13 @@ case class NewInstance(
s"new $className($argString)" s"new $className($argString)"
} }
if (propagateNull && argGen.nonEmpty) { val code = s"""
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" ${argGen.map(_.code).mkString("\n")}
${outer.map(_.code).getOrElse("")}
ev.copy(code = s""" $setIsNull
$setup final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
boolean ${ev.isNull} = true; ev.copy(code = code, isNull = isNull)
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if ($argsNonNull) {
${ev.value} = $constructorCall;
${ev.isNull} = false;
}
""")
} else {
ev.copy(code = s"""
$setup
final $javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = false;
""")
}
} }
override def toString: String = s"newInstance($cls)" override def toString: String = s"newInstance($cls)"
@ -306,13 +284,14 @@ case class UnwrapOption(
val javaType = ctx.javaType(dataType) val javaType = ctx.javaType(dataType)
val inputObject = child.genCode(ctx) val inputObject = child.genCode(ctx)
ev.copy(code = s""" val code = s"""
${inputObject.code} ${inputObject.code}
boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty(); final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} = $javaType ${ev.value} =
${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get(); ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
""") """
ev.copy(code = code)
} }
} }
@ -338,14 +317,14 @@ case class WrapOption(child: Expression, optType: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx) val inputObject = child.genCode(ctx)
ev.copy(code = s""" val code = s"""
${inputObject.code} ${inputObject.code}
boolean ${ev.isNull} = false;
scala.Option ${ev.value} = scala.Option ${ev.value} =
${inputObject.isNull} ? ${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
""") """
ev.copy(code = code, isNull = "false")
} }
} }
@ -474,7 +453,7 @@ case class MapObjects private(
s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
} }
ev.copy(code = s""" val code = s"""
${genInputData.code} ${genInputData.code}
boolean ${ev.isNull} = ${genInputData.value} == null; boolean ${ev.isNull} = ${genInputData.value} == null;
@ -504,7 +483,8 @@ case class MapObjects private(
${ev.isNull} = false; ${ev.isNull} = false;
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
} }
""") """
ev.copy(code = code)
} }
} }
@ -539,14 +519,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
} }
""" """
} }
val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
val schemaField = ctx.addReferenceObj("schema", schema) val schemaField = ctx.addReferenceObj("schema", schema)
ev.copy(code = s"""
boolean ${ev.isNull} = false; val code = s"""
$values = new Object[${children.size}]; $values = new Object[${children.size}];
$childrenCode $childrenCode
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
""") """
ev.copy(code = code, isNull = "false")
} }
} }
@ -579,14 +561,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
// Code to serialize. // Code to serialize.
val input = child.genCode(ctx) val input = child.genCode(ctx)
ev.copy(code = s""" val javaType = ctx.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"
val code = s"""
${input.code} ${input.code}
final boolean ${ev.isNull} = ${input.isNull}; final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; """
if (!${ev.isNull}) { ev.copy(code = code, isNull = input.isNull)
${ev.value} = $serializer.serialize(${input.value}, null).array();
}
""")
} }
override def dataType: DataType = BinaryType override def dataType: DataType = BinaryType
@ -617,17 +599,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
serializer, serializer,
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
// Code to serialize. // Code to deserialize.
val input = child.genCode(ctx) val input = child.genCode(ctx)
ev.copy(code = s""" val javaType = ctx.javaType(dataType)
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
val code = s"""
${input.code} ${input.code}
final boolean ${ev.isNull} = ${input.isNull}; final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; """
if (!${ev.isNull}) { ev.copy(code = code, isNull = input.isNull)
${ev.value} = (${ctx.javaType(dataType)})
$serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
}
""")
} }
override def dataType: DataType = ObjectType(tag.runtimeClass) override def dataType: DataType = ObjectType(tag.runtimeClass)
@ -658,15 +640,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
""" """
} }
ev.isNull = instanceGen.isNull val code = s"""
ev.value = instanceGen.value
ev.copy(code = s"""
${instanceGen.code} ${instanceGen.code}
if (!${instanceGen.isNull}) { if (!${instanceGen.isNull}) {
${initialize.mkString("\n")} ${initialize.mkString("\n")}
} }
""") """
ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
} }
} }
@ -696,13 +676,15 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " + "please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)." "(e.g. java.lang.Integer instead of int/scala.Int)."
val idx = ctx.references.length val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
ctx.references += errMsg
ExprCode(code = s""" val code = s"""
${childGen.code} ${childGen.code}
if (${childGen.isNull}) { if (${childGen.isNull}) {
throw new RuntimeException((String) references[$idx]); throw new RuntimeException(this.$errMsgField);
}""", isNull = "false", value = childGen.value) }
"""
ev.copy(code = code, isNull = "false", value = childGen.value)
} }
} }