[SPARK-14637][SQL] object expressions cleanup
## What changes were proposed in this pull request? Simplify and clean up some object expressions: 1. simplify the logic to handle `propagateNull` 2. add `propagateNull` parameter to `Invoke` 3. simplify the unbox logic in `Invoke` 4. other minor cleanup TODO: simplify `MapObjects` ## How was this patch tested? existing tests. Author: Wenchen Fan <wenchen@databricks.com> Closes #12399 from cloud-fan/object.
This commit is contained in:
parent
214d1be4fd
commit
0513c3ac93
|
@ -64,33 +64,29 @@ case class StaticInvoke(
|
|||
val argGen = arguments.map(_.genCode(ctx))
|
||||
val argString = argGen.map(_.value).mkString(", ")
|
||||
|
||||
if (propagateNull) {
|
||||
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
|
||||
val callFunc = s"$objectName.$functionName($argString)"
|
||||
|
||||
val setIsNull = if (propagateNull && arguments.nonEmpty) {
|
||||
s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
|
||||
} else {
|
||||
s"boolean ${ev.isNull} = false;"
|
||||
}
|
||||
|
||||
// If the function can return null, we do an extra check to make sure our null bit is still set
|
||||
// correctly.
|
||||
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
|
||||
s"${ev.isNull} = ${ev.value} == null;"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
|
||||
ev.copy(code = s"""
|
||||
val code = s"""
|
||||
${argGen.map(_.code).mkString("\n")}
|
||||
|
||||
boolean ${ev.isNull} = !$argsNonNull;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
|
||||
if ($argsNonNull) {
|
||||
${ev.value} = $objectName.$functionName($argString);
|
||||
$objNullCheck
|
||||
}
|
||||
""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
${argGen.map(_.code).mkString("\n")}
|
||||
|
||||
$javaType ${ev.value} = $objectName.$functionName($argString);
|
||||
final boolean ${ev.isNull} = ${ev.value} == null;
|
||||
""")
|
||||
}
|
||||
$setIsNull
|
||||
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
|
||||
$postNullCheck
|
||||
"""
|
||||
ev.copy(code = code)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,7 +107,8 @@ case class Invoke(
|
|||
targetObject: Expression,
|
||||
functionName: String,
|
||||
dataType: DataType,
|
||||
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
|
||||
arguments: Seq[Expression] = Nil,
|
||||
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
|
||||
|
||||
override def nullable: Boolean = true
|
||||
override def children: Seq[Expression] = targetObject +: arguments
|
||||
|
@ -130,60 +127,53 @@ case class Invoke(
|
|||
case _ => None
|
||||
}
|
||||
|
||||
lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
|
||||
case (IntegerType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Integer)$s).intValue()"
|
||||
case (LongType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Long)$s).longValue()"
|
||||
case (FloatType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Float)$s).floatValue()"
|
||||
case (ShortType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Short)$s).shortValue()"
|
||||
case (ByteType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Byte)$s).byteValue()"
|
||||
case (DoubleType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Double)$s).doubleValue()"
|
||||
case (BooleanType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Boolean)$s).booleanValue()"
|
||||
case _ => identity[String] _
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val obj = targetObject.genCode(ctx)
|
||||
val argGen = arguments.map(_.genCode(ctx))
|
||||
val argString = argGen.map(_.value).mkString(", ")
|
||||
|
||||
// If the function can return null, we do an extra check to make sure our null bit is still set
|
||||
// correctly.
|
||||
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
|
||||
s"boolean ${ev.isNull} = ${ev.value} == null;"
|
||||
val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
|
||||
s"${obj.value}.$functionName($argString)"
|
||||
} else {
|
||||
ev.isNull = obj.isNull
|
||||
""
|
||||
s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
|
||||
}
|
||||
|
||||
val value = unboxer(s"${obj.value}.$functionName($argString)")
|
||||
val setIsNull = if (propagateNull && arguments.nonEmpty) {
|
||||
s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
|
||||
} else {
|
||||
s"boolean ${ev.isNull} = ${obj.isNull};"
|
||||
}
|
||||
|
||||
val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
|
||||
s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
|
||||
s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
|
||||
} else {
|
||||
s"""
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
try {
|
||||
${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
|
||||
${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
|
||||
} catch (Exception e) {
|
||||
org.apache.spark.unsafe.Platform.throwException(e);
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
||||
ev.copy(code = s"""
|
||||
// If the function can return null, we do an extra check to make sure our null bit is still set
|
||||
// correctly.
|
||||
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
|
||||
s"${ev.isNull} = ${ev.value} == null;"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
val code = s"""
|
||||
${obj.code}
|
||||
${argGen.map(_.code).mkString("\n")}
|
||||
$setIsNull
|
||||
$evaluate
|
||||
$objNullCheck
|
||||
""")
|
||||
$postNullCheck
|
||||
"""
|
||||
ev.copy(code = code)
|
||||
}
|
||||
|
||||
override def toString: String = s"$targetObject.$functionName"
|
||||
|
@ -246,11 +236,13 @@ case class NewInstance(
|
|||
|
||||
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
|
||||
|
||||
val setup =
|
||||
s"""
|
||||
${argGen.map(_.code).mkString("\n")}
|
||||
${outer.map(_.code).getOrElse("")}
|
||||
""".stripMargin
|
||||
var isNull = ev.isNull
|
||||
val setIsNull = if (propagateNull && arguments.nonEmpty) {
|
||||
s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
|
||||
} else {
|
||||
isNull = "false"
|
||||
""
|
||||
}
|
||||
|
||||
val constructorCall = outer.map { gen =>
|
||||
s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
|
||||
|
@ -258,27 +250,13 @@ case class NewInstance(
|
|||
s"new $className($argString)"
|
||||
}
|
||||
|
||||
if (propagateNull && argGen.nonEmpty) {
|
||||
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
|
||||
|
||||
ev.copy(code = s"""
|
||||
$setup
|
||||
|
||||
boolean ${ev.isNull} = true;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
if ($argsNonNull) {
|
||||
${ev.value} = $constructorCall;
|
||||
${ev.isNull} = false;
|
||||
}
|
||||
""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
$setup
|
||||
|
||||
final $javaType ${ev.value} = $constructorCall;
|
||||
final boolean ${ev.isNull} = false;
|
||||
""")
|
||||
}
|
||||
val code = s"""
|
||||
${argGen.map(_.code).mkString("\n")}
|
||||
${outer.map(_.code).getOrElse("")}
|
||||
$setIsNull
|
||||
final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
|
||||
"""
|
||||
ev.copy(code = code, isNull = isNull)
|
||||
}
|
||||
|
||||
override def toString: String = s"newInstance($cls)"
|
||||
|
@ -306,13 +284,14 @@ case class UnwrapOption(
|
|||
val javaType = ctx.javaType(dataType)
|
||||
val inputObject = child.genCode(ctx)
|
||||
|
||||
ev.copy(code = s"""
|
||||
val code = s"""
|
||||
${inputObject.code}
|
||||
|
||||
boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty();
|
||||
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
|
||||
$javaType ${ev.value} =
|
||||
${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get();
|
||||
""")
|
||||
${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get();
|
||||
"""
|
||||
ev.copy(code = code)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -338,14 +317,14 @@ case class WrapOption(child: Expression, optType: DataType)
|
|||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val inputObject = child.genCode(ctx)
|
||||
|
||||
ev.copy(code = s"""
|
||||
val code = s"""
|
||||
${inputObject.code}
|
||||
|
||||
boolean ${ev.isNull} = false;
|
||||
scala.Option ${ev.value} =
|
||||
${inputObject.isNull} ?
|
||||
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
|
||||
""")
|
||||
"""
|
||||
ev.copy(code = code, isNull = "false")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -474,7 +453,7 @@ case class MapObjects private(
|
|||
s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
|
||||
}
|
||||
|
||||
ev.copy(code = s"""
|
||||
val code = s"""
|
||||
${genInputData.code}
|
||||
|
||||
boolean ${ev.isNull} = ${genInputData.value} == null;
|
||||
|
@ -504,7 +483,8 @@ case class MapObjects private(
|
|||
${ev.isNull} = false;
|
||||
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
|
||||
}
|
||||
""")
|
||||
"""
|
||||
ev.copy(code = code)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -539,14 +519,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
|
|||
}
|
||||
"""
|
||||
}
|
||||
|
||||
val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
|
||||
val schemaField = ctx.addReferenceObj("schema", schema)
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
|
||||
val code = s"""
|
||||
$values = new Object[${children.size}];
|
||||
$childrenCode
|
||||
final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
|
||||
""")
|
||||
"""
|
||||
ev.copy(code = code, isNull = "false")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -579,14 +561,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
|
|||
|
||||
// Code to serialize.
|
||||
val input = child.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val serialize = s"$serializer.serialize(${input.value}, null).array()"
|
||||
|
||||
val code = s"""
|
||||
${input.code}
|
||||
final boolean ${ev.isNull} = ${input.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $serializer.serialize(${input.value}, null).array();
|
||||
}
|
||||
""")
|
||||
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
|
||||
"""
|
||||
ev.copy(code = code, isNull = input.isNull)
|
||||
}
|
||||
|
||||
override def dataType: DataType = BinaryType
|
||||
|
@ -617,17 +599,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
|
|||
serializer,
|
||||
s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
|
||||
|
||||
// Code to serialize.
|
||||
// Code to deserialize.
|
||||
val input = child.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val deserialize =
|
||||
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
|
||||
|
||||
val code = s"""
|
||||
${input.code}
|
||||
final boolean ${ev.isNull} = ${input.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = (${ctx.javaType(dataType)})
|
||||
$serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
|
||||
}
|
||||
""")
|
||||
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
|
||||
"""
|
||||
ev.copy(code = code, isNull = input.isNull)
|
||||
}
|
||||
|
||||
override def dataType: DataType = ObjectType(tag.runtimeClass)
|
||||
|
@ -658,15 +640,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
|
|||
"""
|
||||
}
|
||||
|
||||
ev.isNull = instanceGen.isNull
|
||||
ev.value = instanceGen.value
|
||||
|
||||
ev.copy(code = s"""
|
||||
val code = s"""
|
||||
${instanceGen.code}
|
||||
if (!${instanceGen.isNull}) {
|
||||
${initialize.mkString("\n")}
|
||||
}
|
||||
""")
|
||||
"""
|
||||
ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -696,13 +676,15 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
|
|||
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
|
||||
"please try to use scala.Option[_] or other nullable types " +
|
||||
"(e.g. java.lang.Integer instead of int/scala.Int)."
|
||||
val idx = ctx.references.length
|
||||
ctx.references += errMsg
|
||||
ExprCode(code = s"""
|
||||
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
|
||||
|
||||
val code = s"""
|
||||
${childGen.code}
|
||||
|
||||
if (${childGen.isNull}) {
|
||||
throw new RuntimeException((String) references[$idx]);
|
||||
}""", isNull = "false", value = childGen.value)
|
||||
throw new RuntimeException(this.$errMsgField);
|
||||
}
|
||||
"""
|
||||
ev.copy(code = code, isNull = "false", value = childGen.value)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue