[SPARK-23589][SQL] ExternalMapToCatalyst should support interpreted execution
## What changes were proposed in this pull request?

This PR adds support for interpreted execution to `ExternalMapToCatalyst`.

## How was this patch tested?

Added tests in `ObjectExpressionsSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #20980 from maropu/SPARK-23589.
This commit is contained in:
parent
d87d30e4fe
commit
afbdf42730
|
@ -1255,8 +1255,64 @@ case class ExternalMapToCatalyst private(
|
|||
override def dataType: MapType = {
  // The catalyst map type mirrors the converters' output types; value
  // nullability follows the value converter, keys are always non-null.
  val keyType = keyConverter.dataType
  val valueType = valueConverter.dataType
  MapType(keyType, valueType, valueContainsNull = valueConverter.nullable)
}
|
||||
|
||||
// Interpreted evaluation is not implemented here; codegen is the only path.
override def eval(input: InternalRow): Any = {
  throw new UnsupportedOperationException("Only code-generated evaluation is supported")
}
|
||||
/**
 * Converter from an external map object (either a `java.util.Map` or a
 * `scala.collection.Map`) to parallel arrays of catalyst keys and values.
 * Selected once, lazily, based on the static type of `child`.
 *
 * NOTE(review): the match is non-exhaustive by design — `child.dataType` is
 * expected to be an `ObjectType` of a map class; anything else fails fast
 * with a `MatchError` when the lazy val is first forced.
 */
private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = {
  // Converts one external key. Null keys are rejected because catalyst maps
  // cannot contain them (same message as the generated code path).
  def convertKey(key: Any): Any = {
    if (key == null) {
      throw new RuntimeException("Cannot use null as map key!")
    }
    keyConverter.eval(InternalRow.fromSeq(key :: Nil))
  }

  // Converts one external value; nulls pass through unchanged.
  def convertValue(value: Any): Any = {
    if (value == null) {
      null
    } else {
      valueConverter.eval(InternalRow.fromSeq(value :: Nil))
    }
  }

  child.dataType match {
    case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
      (input: Any) => {
        val data = input.asInstanceOf[java.util.Map[Any, Any]]
        val keys = new Array[Any](data.size)
        val values = new Array[Any](data.size)
        val iter = data.entrySet().iterator()
        var i = 0
        while (iter.hasNext) {
          val entry = iter.next()
          keys(i) = convertKey(entry.getKey)
          values(i) = convertValue(entry.getValue)
          i += 1
        }
        (keys, values)
      }

    case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
      (input: Any) => {
        val data = input.asInstanceOf[scala.collection.Map[Any, Any]]
        val keys = new Array[Any](data.size)
        val values = new Array[Any](data.size)
        var i = 0
        for ((key, value) <- data) {
          keys(i) = convertKey(key)
          values(i) = convertValue(value)
          i += 1
        }
        (keys, values)
      }
  }
}
|
||||
|
||||
/**
 * Interpreted evaluation: converts the child's external map into an
 * `ArrayBasedMapData`. A null child result yields a null map.
 */
override def eval(input: InternalRow): Any = {
  val externalMap = child.eval(input)
  if (externalMap == null) {
    null
  } else {
    val (keys, values) = mapCatalystConverter(externalMap)
    new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
  }
}
|
||||
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val inputMap = child.genCode(ctx)
|
||||
|
|
|
@ -21,12 +21,13 @@ import java.sql.{Date, Timestamp}
|
|||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.reflect.ClassTag
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkFunSuite}
|
||||
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
|
||||
import org.apache.spark.sql.{RandomDataGenerator, Row}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer}
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.encoders._
|
||||
|
@ -501,6 +502,111 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
InternalRow.fromSeq(Seq(Row(1))),
|
||||
"java.lang.Integer is not a valid external type for schema of double")
|
||||
}
|
||||
|
||||
/**
 * Builds an `ExternalMapToCatalyst` serializer for a `java.util.Map` whose
 * key/value runtime classes are given. Test-only helper.
 */
private def javaMapSerializerFor(
    keyClazz: Class[_],
    valueClazz: Class[_])(inputObject: Expression): Expression = {

  // Expression that converts one external key/value object to catalyst form.
  def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match {
    case c if c == classOf[java.lang.Integer] =>
      Invoke(inputObject, "intValue", IntegerType)
    case c if c == classOf[java.lang.String] =>
      StaticInvoke(
        classOf[UTF8String],
        StringType,
        "fromString",
        inputObject :: Nil,
        returnNullable = false)
    // Pass unknown classes through unchanged, mirroring scalaMapSerializerFor,
    // instead of failing with a MatchError.
    case _ =>
      inputObject
  }

  ExternalMapToCatalyst(
    inputObject,
    ObjectType(keyClazz),
    kvSerializerFor(_, keyClazz),
    keyNullable = true,
    ObjectType(valueClazz),
    kvSerializerFor(_, valueClazz),
    valueNullable = true
  )
}
|
||||
|
||||
/**
 * Builds an `ExternalMapToCatalyst` serializer for a Scala map with key type
 * `T` and value type `U`, resolved via runtime reflection. Test-only helper.
 *
 * Fix: removed the unused `curId` AtomicInteger local (dead code).
 */
private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = {
  import org.apache.spark.sql.catalyst.ScalaReflection._

  // Expression that converts one external key/value object to catalyst form;
  // unknown types pass through unchanged.
  def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression =
    localTypeOf[V].dealias match {
      case t if t <:< localTypeOf[java.lang.Integer] =>
        Invoke(inputObject, "intValue", IntegerType)
      case t if t <:< localTypeOf[String] =>
        StaticInvoke(
          classOf[UTF8String],
          StringType,
          "fromString",
          inputObject :: Nil,
          returnNullable = false)
      case _ =>
        inputObject
    }

  ExternalMapToCatalyst(
    inputObject,
    dataTypeFor[T],
    kvSerializerFor[T],
    keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive,
    dataTypeFor[U],
    kvSerializerFor[U],
    valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive
  )
}
|
||||
|
||||
test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") {
  // Same contents in a Scala map and a Java map, so both converter paths of
  // ExternalMapToCatalyst are exercised against the same expected result.
  val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3")
  // Plain `put` calls instead of an anonymous HashMap subclass with an
  // instance initializer (avoids generating a needless inner class).
  val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]()
  javaMap.put(0, "v0")
  javaMap.put(1, "v1")
  javaMap.put(2, null)
  javaMap.put(3, "v3")
  val expected = CatalystTypeConverters.convertToCatalyst(scalaMap)

  // Java Map
  val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])(
    Literal.fromObject(javaMap))
  checkEvaluation(serializer1, expected)

  // Scala Map
  val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap))
  checkEvaluation(serializer2, expected)

  // NULL key test: null map keys must be rejected at evaluation time.
  // `Integer.valueOf` replaces the `new Integer(...)` constructor, which is
  // deprecated since Java 9.
  val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String](
    null.asInstanceOf[java.lang.Integer] -> "v0", java.lang.Integer.valueOf(1) -> "v1")
  val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]()
  javaMapHasNullKey.put(null, "v0")
  javaMapHasNullKey.put(1, "v1")

  // Java Map
  val serializer3 =
    javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])(
      Literal.fromObject(javaMapHasNullKey))
  checkExceptionInExpression[RuntimeException](
    serializer3, EmptyRow, "Cannot use null as map key!")

  // Scala Map
  val serializer4 = scalaMapSerializerFor[java.lang.Integer, String](
    Literal.fromObject(scalaMapHasNullKey))
  checkExceptionInExpression[RuntimeException](
    serializer4, EmptyRow, "Cannot use null as map key!")
}
|
||||
}
|
||||
|
||||
class TestBean extends Serializable {
|
||||
|
|
Loading…
Reference in a new issue