[SPARK-23583][SQL] Invoke should support interpreted execution

## What changes were proposed in this pull request?

This PR adds interpreted execution for `Invoke`.
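
To illustrate the change (not part of the patch): with interpreted execution, an `Invoke` expression can be evaluated directly through `eval`, where it previously threw `UnsupportedOperationException` outside codegen. A minimal sketch, assuming a hypothetical target class `Incrementer`:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{IntegerType, ObjectType}

// Hypothetical target class, for illustration only.
class Incrementer { def bump(i: Int): Int = i + 1 }

val target = Literal.create(new Incrementer, ObjectType(classOf[Incrementer]))
val expr = Invoke(target, "bump", IntegerType, Seq(Literal(41)))

// Interpreted path: evaluates via Java reflection, no code generation involved.
assert(expr.eval(InternalRow.empty) == 42)
```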

## How was this patch tested?

Added tests in `ObjectExpressionsSuite`.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #20797 from kiszk/SPARK-23583.
Kazuaki Ishizaki 2018-04-04 18:36:15 +02:00 committed by Herman van Hovell
parent 5197562afe
commit a35523653c
3 changed files with 163 additions and 6 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection {
    "interface", "long", "native", "new", "null", "package", "private", "protected", "public",
    "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
    "throws", "transient", "true", "try", "void", "volatile", "while")
+
+  val typeJavaMapping = Map[DataType, Class[_]](
+    BooleanType -> classOf[Boolean],
+    ByteType -> classOf[Byte],
+    ShortType -> classOf[Short],
+    IntegerType -> classOf[Int],
+    LongType -> classOf[Long],
+    FloatType -> classOf[Float],
+    DoubleType -> classOf[Double],
+    StringType -> classOf[UTF8String],
+    DateType -> classOf[DateType.InternalType],
+    TimestampType -> classOf[TimestampType.InternalType],
+    BinaryType -> classOf[BinaryType.InternalType],
+    CalendarIntervalType -> classOf[CalendarInterval]
+  )
+
+  val typeBoxedJavaMapping = Map[DataType, Class[_]](
+    BooleanType -> classOf[java.lang.Boolean],
+    ByteType -> classOf[java.lang.Byte],
+    ShortType -> classOf[java.lang.Short],
+    IntegerType -> classOf[java.lang.Integer],
+    LongType -> classOf[java.lang.Long],
+    FloatType -> classOf[java.lang.Float],
+    DoubleType -> classOf[java.lang.Double],
+    DateType -> classOf[java.lang.Integer],
+    TimestampType -> classOf[java.lang.Long]
+  )
+
+  def dataTypeJavaClass(dt: DataType): Class[_] = {
+    dt match {
+      case _: DecimalType => classOf[Decimal]
+      case _: StructType => classOf[InternalRow]
+      case _: ArrayType => classOf[ArrayData]
+      case _: MapType => classOf[MapData]
+      case ObjectType(cls) => cls
+      case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object])
+    }
+  }
+
+  def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
+    if (arguments != Nil) {
+      arguments.map(e => dataTypeJavaClass(e.dataType))
+    } else {
+      Seq.empty
+    }
+  }
}
/**
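
For reference, a quick illustration of how the new `ScalaReflection` helpers behave, following the mappings above (a sketch, not part of the diff):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.types._

// Primitive Catalyst types map to unboxed JVM classes, while string data
// maps to the internal UTF8String representation.
assert(ScalaReflection.dataTypeJavaClass(IntegerType) == classOf[Int])
assert(ScalaReflection.dataTypeJavaClass(StringType) ==
  classOf[org.apache.spark.unsafe.types.UTF8String])

// expressionJavaClasses derives the reflective parameter types from the
// Catalyst data types of the argument expressions.
val args = Seq(BoundReference(0, IntegerType, nullable = false), Literal(1.0))
assert(ScalaReflection.expressionJavaClasses(args) == Seq(classOf[Int], classOf[Double]))
```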

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.objects

-import java.lang.reflect.Modifier
+import java.lang.reflect.{Method, Modifier}

import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
@@ -28,7 +28,7 @@ import scala.util.Try

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression {
    (argCode, argValues.mkString(", "), resultIsNull)
  }

+  /**
+   * Evaluate each argument with a given row, invoke a method with a given object and arguments,
+   * and cast a return value if the return type can be mapped to a Java boxed type.
+   *
+   * @param obj the object for the method to be called. If null, perform a static method call
+   * @param method the method object to be called
+   * @param arguments the arguments used for the method call
+   * @param input the row used for evaluating arguments
+   * @param dataType the data type of the return object
+   * @return the return object of a method call
+   */
+  def invoke(
+      obj: Any,
+      method: Method,
+      arguments: Seq[Expression],
+      input: InternalRow,
+      dataType: DataType): Any = {
+    val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
+    if (needNullCheck && args.exists(_ == null)) {
+      // return null if one of arguments is null
+      null
+    } else {
+      val ret = method.invoke(obj, args: _*)
+      val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType)
+      if (boxedClass.isDefined) {
+        boxedClass.get.cast(ret)
+      } else {
+        ret
+      }
+    }
+  }
}
/**
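
The `invoke` helper mirrors the generated code's null handling: when `needNullCheck` holds (`propagateNull` is set and some argument is nullable), a null argument short-circuits to null without calling the method. A sketch of the observable behavior through `Invoke`, with an illustrative target class:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.ObjectType

// Illustrative target class, not part of the patch.
class Adder { def add(i: java.lang.Integer): java.lang.Integer = i + 1 }

val target = Literal.create(new Adder, ObjectType(classOf[Adder]))
val arg = BoundReference(0, ObjectType(classOf[java.lang.Integer]), nullable = true)
val expr = Invoke(target, "add", ObjectType(classOf[java.lang.Integer]), Seq(arg))

// A nullable argument makes needNullCheck true, so a null input
// yields null instead of failing inside reflection.
assert(expr.eval(InternalRow.fromSeq(Seq(null))) == null)
assert(expr.eval(InternalRow.fromSeq(Seq(Integer.valueOf(1)))) == 2)
```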
@@ -264,12 +296,11 @@ case class Invoke(
    propagateNull: Boolean = true,
    returnNullable : Boolean = true) extends InvokeLike {

+  lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+
  override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
  override def children: Seq[Expression] = targetObject +: arguments

-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
-
  private lazy val encodedFunctionName = TermName(functionName).encodedName.toString

  @transient lazy val method = targetObject.dataType match {
@@ -283,6 +314,21 @@ case class Invoke(
    case _ => None
  }

+  override def eval(input: InternalRow): Any = {
+    val obj = targetObject.eval(input)
+    if (obj == null) {
+      // return null if obj is null
+      null
+    } else {
+      val invokeMethod = if (method.isDefined) {
+        method.get
+      } else {
+        obj.getClass.getDeclaredMethod(functionName, argClasses: _*)
+      }
+      invoke(obj, invokeMethod, arguments, input, dataType)
+    }
+  }
+
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)
    val obj = targetObject.genCode(ctx)
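
When the target's static type yields no pre-resolved `method` (the `None` case above), `eval` falls back to looking the method up on the runtime class of the evaluated target. Roughly what that fallback does, shown standalone with an illustrative class:

```scala
import java.lang.reflect.Method
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Literal

// Illustrative class; any method resolvable by name and parameter types works.
class Doubler { def twice(i: Int): Int = i * 2 }

val obj: Any = new Doubler
val argClasses = ScalaReflection.expressionJavaClasses(Seq(Literal(21)))

// The same reflective lookup Invoke.eval performs when `method` is None.
val m: Method = obj.getClass.getDeclaredMethod("twice", argClasses: _*)
assert(m.invoke(obj, Integer.valueOf(21)) == 42)
```

Note that `Method.invoke` dispatches on the runtime receiver, which is why the `InvokeTargetSubClass` test below sees the overridden `binOp`.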

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

@@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

+class InvokeTargetClass extends Serializable {
+  def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
+  def filterPrimitiveInt(e: Int): Boolean = e > 0
+  def binOp(e1: Int, e2: Double): Double = e1 + e2
+}
+
+class InvokeTargetSubClass extends InvokeTargetClass {
+  override def binOp(e1: Int, e2: Double): Double = e1 - e2
+}
+
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
      UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
  }

+  test("SPARK-23583: Invoke should support interpreted execution") {
+    val targetObject = new InvokeTargetClass
+    val funcClass = classOf[InvokeTargetClass]
+    val funcObj = Literal.create(targetObject, ObjectType(funcClass))
+    val targetSubObject = new InvokeTargetSubClass
+    val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass]))
+    val funcNullObj = Literal.create(null, ObjectType(funcClass))
+
+    val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true))
+    val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false))
+    val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+      java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt),
+      false, InternalRow.fromSeq(Seq(-1)))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+      null, InternalRow.fromSeq(Seq(null)))
+
+    checkObjectExprEvaluation(
+      Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+      null, InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25))
+
+    checkObjectExprEvaluation(
+      Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25))
+  }
+
  test("SPARK-23585: UnwrapOption should support interpreted execution") {
    val cls = classOf[Option[Int]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
@@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
    checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
  }
+  // Checks evaluation results by scala values instead of catalyst values.
+  private def checkObjectExprEvaluation(
+      expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
+    val serializer = new JavaSerializer(new SparkConf()).newInstance
+    val resolver = ResolveTimeZone(new SQLConf)
+    val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
+    checkEvaluationWithoutCodegen(expr, expected, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow)
+    if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+      checkEvaluationWithUnsafeProjection(
+        expr,
+        expected,
+        inputRow,
+        UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
+    }
+    checkEvaluationWithOptimization(expr, expected, inputRow)
+  }
+
  test("SPARK-23594 GetExternalRowField should support interpreted execution") {
    val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
    val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")