[SPARK-25772][SQL] Fix java map of structs deserialization

This is a follow-up PR for #22708. It covers another case of Java bean deserialization: Java maps with struct keys/values.

When MapType values with struct keys/values are deserialized into Java beans, the struct fields get mixed up, because the struct types are inferred from the bean rather than taken from the input. I suggest using the struct data types retrieved from the resolved input data instead of inferring them from the Java beans.
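For illustration, a minimal repro sketch (class and object names here are hypothetical; bean-style Scala classes via `@BeanProperty` stand in for the Java beans used in the tests below):

```scala
import scala.beans.BeanProperty

import org.apache.spark.sql.{Encoders, SparkSession}

// Bean-style classes mirroring the beans in the test suite added below.
class Interval extends Serializable {
  @BeanProperty var startTime: Long = _
  @BeanProperty var endTime: Long = _
}

class MapRecord extends Serializable {
  @BeanProperty var id: Int = _
  @BeanProperty var intervals: java.util.Map[String, Interval] = _
}

object MapOfStructsRepro extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  // The input schema declares startTime before endTime; the struct type
  // inferred from the bean may order the fields differently, which is
  // what used to scramble the values.
  val json = Seq("""{"id": 1, "intervals": {"a": {"startTime": 111, "endTime": 211}}}""").toDS()
  val record = spark.read
    .schema("id int, intervals map<string, struct<startTime: bigint, endTime: bigint>>")
    .json(json)
    .as(Encoders.bean(classOf[MapRecord]))
    .head()

  val a = record.getIntervals.get("a")
  println(s"${a.getStartTime} ${a.getEndTime}") // with the fix: "111 211", not swapped
  spark.stop()
}
```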

## What changes were proposed in this pull request?

Invocations of "keyArray" and "valueArray" functions are used to extract arrays of keys and values. Struct type of keys or values is also inferred from java bean structure and ends up with mixed up field order.
I created a new UnresolvedInvoke expression as a temporary substitution of Invoke expression while no actual data is available. It allows to provide the resulting data type during analysis based on the resolved input data, not on the java bean (similar to UnresolvedMapObjects).

The key and value arrays are then fed to a MapObjects expression, which I replaced with UnresolvedMapObjects, just as in the ArrayType case.

Finally, no additional pattern-matching case is needed in Analyzer.resolveExpression: GetArrayFromMap derives its data type lazily from its child, so it resolves as soon as the input data is resolved (see the sketch below).
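A minimal sketch of that behavior (assuming a Spark build that includes this patch; GetKeyArrayFromMap and friends are catalyst-internal classes, not public API, and GetArrayFromMapSketch is a hypothetical driver name): once the child expression is resolved, the array element types fall out of its MapType, with no reference to the bean.

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.{GetKeyArrayFromMap, GetValueArrayFromMap}
import org.apache.spark.sql.types._

object GetArrayFromMapSketch extends App {
  // A resolved input column: map<string, struct<startTime: bigint, endTime: bigint>>
  val valueStruct = new StructType().add("startTime", LongType).add("endTime", LongType)
  val mapCol = BoundReference(0, MapType(StringType, valueStruct), nullable = true)

  // The element types come from the resolved input, not from a bean:
  assert(GetKeyArrayFromMap(mapCol).dataType == ArrayType(StringType))
  assert(GetValueArrayFromMap(mapCol).dataType == ArrayType(valueStruct))
}
```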

## How was this patch tested?

Added a test case.
Built the complete project on Travis.

viirya kiszk cloud-fan michalsenkyr marmbrus liancheng

Closes #22745 from vofque/SPARK-21402-FOLLOWUP.

Lead-authored-by: Vladimir Kuriatkov <vofque@gmail.com>
Co-authored-by: Vladimir Kuriatkov <Vladimir_Kuriatkov@epam.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Author: Vladimir Kuriatkov
Date: 2018-10-24 09:29:40 +08:00
Committer: Wenchen Fan
Parent: 4506dad8a9
Commit: 584e767d37
5 changed files with 325 additions and 162 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

@@ -278,24 +278,20 @@ object JavaTypeInference {
       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
-        val keyDataType = inferDataType(keyType)._1
-        val valueDataType = inferDataType(valueType)._1
 
         val keyData =
           Invoke(
-            MapObjects(
+            UnresolvedMapObjects(
               p => deserializerFor(keyType, Some(p)),
-              Invoke(getPath, "keyArray", ArrayType(keyDataType)),
-              keyDataType),
+              GetKeyArrayFromMap(getPath)),
             "array",
             ObjectType(classOf[Array[Any]]))
 
         val valueData =
           Invoke(
-            MapObjects(
+            UnresolvedMapObjects(
               p => deserializerFor(valueType, Some(p)),
-              Invoke(getPath, "valueArray", ArrayType(valueDataType)),
-              valueDataType),
+              GetValueArrayFromMap(getPath)),
             "array",
             ObjectType(classOf[Array[Any]]))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

@@ -30,6 +30,7 @@ import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -1787,3 +1788,78 @@ case class ValidateExternalType(child: Expression, expected: DataType)
     ev.copy(code = code, isNull = input.isNull)
   }
 }
+
+object GetKeyArrayFromMap {
+
+  /**
+   * Construct an instance of GetArrayFromMap case class
+   * extracting a key array from a Map expression.
+   *
+   * @param child a Map expression to extract a key array from
+   */
+  def apply(child: Expression): Expression = {
+    GetArrayFromMap(
+      child,
+      "keyArray",
+      _.keyArray(),
+      { case MapType(kt, _, _) => kt })
+  }
+}
+
+object GetValueArrayFromMap {
+
+  /**
+   * Construct an instance of GetArrayFromMap case class
+   * extracting a value array from a Map expression.
+   *
+   * @param child a Map expression to extract a value array from
+   */
+  def apply(child: Expression): Expression = {
+    GetArrayFromMap(
+      child,
+      "valueArray",
+      _.valueArray(),
+      { case MapType(_, vt, _) => vt })
+  }
+}
+
+/**
+ * Extracts a key/value array from a Map expression.
+ *
+ * @param child a Map expression to extract an array from
+ * @param functionName name of the function that is invoked to extract an array
+ * @param arrayGetter function extracting `ArrayData` from `MapData`
+ * @param elementTypeGetter function extracting array element `DataType` from `MapType`
+ */
+case class GetArrayFromMap private(
+    child: Expression,
+    functionName: String,
+    arrayGetter: MapData => ArrayData,
+    elementTypeGetter: MapType => DataType) extends UnaryExpression with NonSQLExpression {
+
+  private lazy val encodedFunctionName: String = TermName(functionName).encodedName.toString
+
+  lazy val dataType: DataType = {
+    val mt: MapType = child.dataType.asInstanceOf[MapType]
+    ArrayType(elementTypeGetter(mt))
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (child.dataType.isInstanceOf[MapType]) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"Can't extract array from $child: need map type but got ${child.dataType.catalogString}")
+    }
+  }
+
+  override def nullSafeEval(input: Any): Any = {
+    arrayGetter(input.asInstanceOf[MapData])
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, childValue => s"$childValue.$encodedFunctionName()")
+  }
+
+  override def toString: String = s"$child.$functionName"
+}

sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java (new file)

@@ -0,0 +1,240 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql;

import java.io.Serializable;
import java.util.*;

import org.junit.*;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.test.TestSparkSession;

public class JavaBeanDeserializationSuite implements Serializable {

  private TestSparkSession spark;

  @Before
  public void setUp() {
    spark = new TestSparkSession();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }

  private static final List<ArrayRecord> ARRAY_RECORDS = new ArrayList<>();

  static {
    ARRAY_RECORDS.add(
      new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
    );
    ARRAY_RECORDS.add(
      new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
    );
    ARRAY_RECORDS.add(
      new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
    );
  }

  @Test
  public void testBeanWithArrayFieldDeserialization() {
    Encoder<ArrayRecord> encoder = Encoders.bean(ArrayRecord.class);

    Dataset<ArrayRecord> dataset = spark
      .read()
      .format("json")
      .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>")
      .load("src/test/resources/test-data/with-array-fields.json")
      .as(encoder);

    List<ArrayRecord> records = dataset.collectAsList();

    Assert.assertEquals(records, ARRAY_RECORDS);
  }

  private static final List<MapRecord> MAP_RECORDS = new ArrayList<>();

  static {
    MAP_RECORDS.add(new MapRecord(1,
      toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
    ));
    MAP_RECORDS.add(new MapRecord(2,
      toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
    ));
    MAP_RECORDS.add(new MapRecord(3,
      toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
    ));
    MAP_RECORDS.add(new MapRecord(4, new HashMap<>()));
    MAP_RECORDS.add(new MapRecord(5, null));
  }

  private static <K, V> Map<K, V> toMap(Collection<K> keys, Collection<V> values) {
    Map<K, V> map = new HashMap<>();
    Iterator<K> keyI = keys.iterator();
    Iterator<V> valueI = values.iterator();
    while (keyI.hasNext() && valueI.hasNext()) {
      map.put(keyI.next(), valueI.next());
    }
    return map;
  }

  @Test
  public void testBeanWithMapFieldsDeserialization() {
    Encoder<MapRecord> encoder = Encoders.bean(MapRecord.class);

    Dataset<MapRecord> dataset = spark
      .read()
      .format("json")
      .schema("id int, intervals map<string, struct<startTime: bigint, endTime: bigint>>")
      .load("src/test/resources/test-data/with-map-fields.json")
      .as(encoder);

    List<MapRecord> records = dataset.collectAsList();

    Assert.assertEquals(records, MAP_RECORDS);
  }

  public static class ArrayRecord {

    private int id;
    private List<Interval> intervals;

    public ArrayRecord() { }

    ArrayRecord(int id, List<Interval> intervals) {
      this.id = id;
      this.intervals = intervals;
    }

    public int getId() {
      return id;
    }

    public void setId(int id) {
      this.id = id;
    }

    public List<Interval> getIntervals() {
      return intervals;
    }

    public void setIntervals(List<Interval> intervals) {
      this.intervals = intervals;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof ArrayRecord)) return false;
      ArrayRecord other = (ArrayRecord) obj;
      return (other.id == this.id) && other.intervals.equals(this.intervals);
    }

    @Override
    public String toString() {
      return String.format("{ id: %d, intervals: %s }", id, intervals);
    }
  }

  public static class MapRecord {

    private int id;
    private Map<String, Interval> intervals;

    public MapRecord() { }

    MapRecord(int id, Map<String, Interval> intervals) {
      this.id = id;
      this.intervals = intervals;
    }

    public int getId() {
      return id;
    }

    public void setId(int id) {
      this.id = id;
    }

    public Map<String, Interval> getIntervals() {
      return intervals;
    }

    public void setIntervals(Map<String, Interval> intervals) {
      this.intervals = intervals;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof MapRecord)) return false;
      MapRecord other = (MapRecord) obj;
      return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
    }

    @Override
    public String toString() {
      return String.format("{ id: %d, intervals: %s }", id, intervals);
    }
  }

  public static class Interval {

    private long startTime;
    private long endTime;

    public Interval() { }

    Interval(long startTime, long endTime) {
      this.startTime = startTime;
      this.endTime = endTime;
    }

    public long getStartTime() {
      return startTime;
    }

    public void setStartTime(long startTime) {
      this.startTime = startTime;
    }

    public long getEndTime() {
      return endTime;
    }

    public void setEndTime(long endTime) {
      this.endTime = endTime;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof Interval)) return false;
      Interval other = (Interval) obj;
      return (other.startTime == this.startTime) && (other.endTime == this.endTime);
    }

    @Override
    public String toString() {
      return String.format("[%d,%d]", startTime, endTime);
    }
  }
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanWithArraySuite.java (deleted; its array test moved into JavaBeanDeserializationSuite)

@@ -1,154 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.test.TestSparkSession;

public class JavaBeanWithArraySuite {

  private static final List<Record> RECORDS = new ArrayList<>();

  static {
    RECORDS.add(new Record(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221))));
    RECORDS.add(new Record(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222))));
    RECORDS.add(new Record(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223))));
  }

  private TestSparkSession spark;

  @Before
  public void setUp() {
    spark = new TestSparkSession();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }

  @Test
  public void testBeanWithArrayFieldDeserialization() {
    Encoder<Record> encoder = Encoders.bean(Record.class);

    Dataset<Record> dataset = spark
      .read()
      .format("json")
      .schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>")
      .load("src/test/resources/test-data/with-array-fields.json")
      .as(encoder);

    List<Record> records = dataset.collectAsList();

    Assert.assertEquals(records, RECORDS);
  }

  public static class Record {

    private int id;
    private List<Interval> intervals;

    public Record() { }

    Record(int id, List<Interval> intervals) {
      this.id = id;
      this.intervals = intervals;
    }

    public int getId() {
      return id;
    }

    public void setId(int id) {
      this.id = id;
    }

    public List<Interval> getIntervals() {
      return intervals;
    }

    public void setIntervals(List<Interval> intervals) {
      this.intervals = intervals;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof Record)) return false;
      Record other = (Record) obj;
      return (other.id == this.id) && other.intervals.equals(this.intervals);
    }

    @Override
    public String toString() {
      return String.format("{ id: %d, intervals: %s }", id, intervals);
    }
  }

  public static class Interval {

    private long startTime;
    private long endTime;

    public Interval() { }

    Interval(long startTime, long endTime) {
      this.startTime = startTime;
      this.endTime = endTime;
    }

    public long getStartTime() {
      return startTime;
    }

    public void setStartTime(long startTime) {
      this.startTime = startTime;
    }

    public long getEndTime() {
      return endTime;
    }

    public void setEndTime(long endTime) {
      this.endTime = endTime;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof Interval)) return false;
      Interval other = (Interval) obj;
      return (other.startTime == this.startTime) && (other.endTime == this.endTime);
    }

    @Override
    public String toString() {
      return String.format("[%d,%d]", startTime, endTime);
    }
  }
}

sql/core/src/test/resources/test-data/with-map-fields.json (new file)

@@ -0,0 +1,5 @@
{ "id": 1, "intervals": { "a": { "startTime": 111, "endTime": 211 }, "b": { "startTime": 121, "endTime": 221 }}}
{ "id": 2, "intervals": { "a": { "startTime": 112, "endTime": 212 }, "b": { "startTime": 122, "endTime": 222 }}}
{ "id": 3, "intervals": { "a": { "startTime": 113, "endTime": 213 }, "b": { "startTime": 123, "endTime": 223 }}}
{ "id": 4, "intervals": { }}
{ "id": 5 }