[SPARK-23583][SQL] Invoke should support interpreted execution
## What changes were proposed in this pull request? This pr added interpreted execution for `Invoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #20797 from kiszk/SPARK-28583.
This commit is contained in:
parent
5197562afe
commit
a35523653c
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
|
|||
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.objects._
|
||||
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
||||
|
@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection {
|
|||
"interface", "long", "native", "new", "null", "package", "private", "protected", "public",
|
||||
"return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
|
||||
"throws", "transient", "true", "try", "void", "volatile", "while")
|
||||
|
||||
val typeJavaMapping = Map[DataType, Class[_]](
|
||||
BooleanType -> classOf[Boolean],
|
||||
ByteType -> classOf[Byte],
|
||||
ShortType -> classOf[Short],
|
||||
IntegerType -> classOf[Int],
|
||||
LongType -> classOf[Long],
|
||||
FloatType -> classOf[Float],
|
||||
DoubleType -> classOf[Double],
|
||||
StringType -> classOf[UTF8String],
|
||||
DateType -> classOf[DateType.InternalType],
|
||||
TimestampType -> classOf[TimestampType.InternalType],
|
||||
BinaryType -> classOf[BinaryType.InternalType],
|
||||
CalendarIntervalType -> classOf[CalendarInterval]
|
||||
)
|
||||
|
||||
val typeBoxedJavaMapping = Map[DataType, Class[_]](
|
||||
BooleanType -> classOf[java.lang.Boolean],
|
||||
ByteType -> classOf[java.lang.Byte],
|
||||
ShortType -> classOf[java.lang.Short],
|
||||
IntegerType -> classOf[java.lang.Integer],
|
||||
LongType -> classOf[java.lang.Long],
|
||||
FloatType -> classOf[java.lang.Float],
|
||||
DoubleType -> classOf[java.lang.Double],
|
||||
DateType -> classOf[java.lang.Integer],
|
||||
TimestampType -> classOf[java.lang.Long]
|
||||
)
|
||||
|
||||
def dataTypeJavaClass(dt: DataType): Class[_] = {
|
||||
dt match {
|
||||
case _: DecimalType => classOf[Decimal]
|
||||
case _: StructType => classOf[InternalRow]
|
||||
case _: ArrayType => classOf[ArrayData]
|
||||
case _: MapType => classOf[MapData]
|
||||
case ObjectType(cls) => cls
|
||||
case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object])
|
||||
}
|
||||
}
|
||||
|
||||
def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
|
||||
if (arguments != Nil) {
|
||||
arguments.map(e => dataTypeJavaClass(e.dataType))
|
||||
} else {
|
||||
Seq.empty
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions.objects
|
||||
|
||||
import java.lang.reflect.Modifier
|
||||
import java.lang.reflect.{Method, Modifier}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.Builder
|
||||
|
@ -28,7 +28,7 @@ import scala.util.Try
|
|||
import org.apache.spark.{SparkConf, SparkEnv}
|
||||
import org.apache.spark.serializer._
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
|
||||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression {
|
|||
|
||||
(argCode, argValues.mkString(", "), resultIsNull)
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluate each argument with a given row, invoke a method with a given object and arguments,
|
||||
* and cast a return value if the return type can be mapped to a Java Boxed type
|
||||
*
|
||||
* @param obj the object for the method to be called. If null, perform s static method call
|
||||
* @param method the method object to be called
|
||||
* @param arguments the arguments used for the method call
|
||||
* @param input the row used for evaluating arguments
|
||||
* @param dataType the data type of the return object
|
||||
* @return the return object of a method call
|
||||
*/
|
||||
def invoke(
|
||||
obj: Any,
|
||||
method: Method,
|
||||
arguments: Seq[Expression],
|
||||
input: InternalRow,
|
||||
dataType: DataType): Any = {
|
||||
val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
|
||||
if (needNullCheck && args.exists(_ == null)) {
|
||||
// return null if one of arguments is null
|
||||
null
|
||||
} else {
|
||||
val ret = method.invoke(obj, args: _*)
|
||||
val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType)
|
||||
if (boxedClass.isDefined) {
|
||||
boxedClass.get.cast(ret)
|
||||
} else {
|
||||
ret
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -264,12 +296,11 @@ case class Invoke(
|
|||
propagateNull: Boolean = true,
|
||||
returnNullable : Boolean = true) extends InvokeLike {
|
||||
|
||||
lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
|
||||
|
||||
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
|
||||
override def children: Seq[Expression] = targetObject +: arguments
|
||||
|
||||
override def eval(input: InternalRow): Any =
|
||||
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
|
||||
|
||||
private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
|
||||
|
||||
@transient lazy val method = targetObject.dataType match {
|
||||
|
@ -283,6 +314,21 @@ case class Invoke(
|
|||
case _ => None
|
||||
}
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
val obj = targetObject.eval(input)
|
||||
if (obj == null) {
|
||||
// return null if obj is null
|
||||
null
|
||||
} else {
|
||||
val invokeMethod = if (method.isDefined) {
|
||||
method.get
|
||||
} else {
|
||||
obj.getClass.getDeclaredMethod(functionName, argClasses: _*)
|
||||
}
|
||||
invoke(obj, invokeMethod, arguments, input, dataType)
|
||||
}
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val obj = targetObject.genCode(ctx)
|
||||
|
|
|
@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
|
|||
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
|
||||
import org.apache.spark.sql.catalyst.expressions.objects._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class InvokeTargetClass extends Serializable {
|
||||
def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
|
||||
def filterPrimitiveInt(e: Int): Boolean = e > 0
|
||||
def binOp(e1: Int, e2: Double): Double = e1 + e2
|
||||
}
|
||||
|
||||
class InvokeTargetSubClass extends InvokeTargetClass {
|
||||
override def binOp(e1: Int, e2: Double): Double = e1 - e2
|
||||
}
|
||||
|
||||
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||
|
||||
|
@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
|
||||
}
|
||||
|
||||
test("SPARK-23583: Invoke should support interpreted execution") {
|
||||
val targetObject = new InvokeTargetClass
|
||||
val funcClass = classOf[InvokeTargetClass]
|
||||
val funcObj = Literal.create(targetObject, ObjectType(funcClass))
|
||||
val targetSubObject = new InvokeTargetSubClass
|
||||
val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass]))
|
||||
val funcNullObj = Literal.create(null, ObjectType(funcClass))
|
||||
|
||||
val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true))
|
||||
val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false))
|
||||
val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false))
|
||||
|
||||
checkObjectExprEvaluation(
|
||||
Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
|
||||
java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1))))
|
||||
|
||||
checkObjectExprEvaluation(
|
||||
Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt),
|
||||
false, InternalRow.fromSeq(Seq(-1)))
|
||||
|
||||
checkObjectExprEvaluation(
|
||||
Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
|
||||
null, InternalRow.fromSeq(Seq(null)))
|
||||
|
||||
checkObjectExprEvaluation(
|
||||
Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt),
|
||||
null, InternalRow.fromSeq(Seq(Integer.valueOf(1))))
|
||||
|
||||
checkObjectExprEvaluation(
|
||||
Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25))
|
||||
|
||||
checkObjectExprEvaluation(
|
||||
Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25))
|
||||
}
|
||||
|
||||
test("SPARK-23585: UnwrapOption should support interpreted execution") {
|
||||
val cls = classOf[Option[Int]]
|
||||
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
|
||||
|
@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
|
||||
}
|
||||
|
||||
// by scala values instead of catalyst values.
|
||||
private def checkObjectExprEvaluation(
|
||||
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
|
||||
val serializer = new JavaSerializer(new SparkConf()).newInstance
|
||||
val resolver = ResolveTimeZone(new SQLConf)
|
||||
val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
|
||||
checkEvaluationWithoutCodegen(expr, expected, inputRow)
|
||||
checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow)
|
||||
if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
|
||||
checkEvaluationWithUnsafeProjection(
|
||||
expr,
|
||||
expected,
|
||||
inputRow,
|
||||
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
|
||||
}
|
||||
checkEvaluationWithOptimization(expr, expected, inputRow)
|
||||
}
|
||||
|
||||
test("SPARK-23594 GetExternalRowField should support interpreted execution") {
|
||||
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
|
||||
val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")
|
||||
|
|
Loading…
Reference in a new issue