[SPARK-25772][SQL] Fix java map of structs deserialization
This is a follow-up PR for #22708. It considers another case of Java bean deserialization: Java maps with struct keys/values. When deserializing values of MapType with struct keys/values in Java beans, the fields of the structs get mixed up. I suggest using the struct data types retrieved from the resolved input data instead of inferring them from the Java beans. ## What changes were proposed in this pull request? Invocations of the "keyArray" and "valueArray" functions are used to extract arrays of keys and values. The struct type of the keys or values is also inferred from the Java bean structure and ends up with a mixed-up field order. I created a new UnresolvedInvoke expression as a temporary substitute for the Invoke expression while no actual data is available. It allows the resulting data type to be provided during analysis based on the resolved input data, not on the Java bean (similar to UnresolvedMapObjects). The key and value arrays are then fed to the MapObjects expression, which I replaced with UnresolvedMapObjects, just like in the case of ArrayType. Finally, I added resolution of UnresolvedInvoke expressions in the Analyzer.resolveExpression method as an additional pattern matching case. ## How was this patch tested? Added a test case. Built the complete project on Travis. viirya kiszk cloud-fan michalsenkyr marmbrus liancheng Closes #22745 from vofque/SPARK-21402-FOLLOWUP. Lead-authored-by: Vladimir Kuriatkov <vofque@gmail.com> Co-authored-by: Vladimir Kuriatkov <Vladimir_Kuriatkov@epam.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
4506dad8a9
commit
584e767d37
|
@ -278,24 +278,20 @@ object JavaTypeInference {
|
|||
|
||||
case _ if mapType.isAssignableFrom(typeToken) =>
|
||||
val (keyType, valueType) = mapKeyValueType(typeToken)
|
||||
val keyDataType = inferDataType(keyType)._1
|
||||
val valueDataType = inferDataType(valueType)._1
|
||||
|
||||
val keyData =
|
||||
Invoke(
|
||||
MapObjects(
|
||||
UnresolvedMapObjects(
|
||||
p => deserializerFor(keyType, Some(p)),
|
||||
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
|
||||
keyDataType),
|
||||
GetKeyArrayFromMap(getPath)),
|
||||
"array",
|
||||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
val valueData =
|
||||
Invoke(
|
||||
MapObjects(
|
||||
UnresolvedMapObjects(
|
||||
p => deserializerFor(valueType, Some(p)),
|
||||
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
|
||||
valueDataType),
|
||||
GetValueArrayFromMap(getPath)),
|
||||
"array",
|
||||
ObjectType(classOf[Array[Any]]))
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.spark.serializer._
|
|||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
|
@ -1787,3 +1788,78 @@ case class ValidateExternalType(child: Expression, expected: DataType)
|
|||
ev.copy(code = code, isNull = input.isNull)
|
||||
}
|
||||
}
|
||||
|
||||
object GetKeyArrayFromMap {

  /**
   * Builds a `GetArrayFromMap` expression that yields the keys of a map as an array.
   *
   * @param child a Map expression whose keys are extracted
   */
  def apply(child: Expression): Expression = {
    // Interpreted-mode accessor for the key array of the map data.
    val extractKeys: MapData => ArrayData = mapData => mapData.keyArray()
    // The resulting array's element type is the key type of the map.
    val keyTypeOf: MapType => DataType = mapType => mapType.keyType
    GetArrayFromMap(child, "keyArray", extractKeys, keyTypeOf)
  }
}
|
||||
|
||||
object GetValueArrayFromMap {

  /**
   * Builds a `GetArrayFromMap` expression that yields the values of a map as an array.
   *
   * @param child a Map expression whose values are extracted
   */
  def apply(child: Expression): Expression = {
    // Interpreted-mode accessor for the value array of the map data.
    val extractValues: MapData => ArrayData = mapData => mapData.valueArray()
    // The resulting array's element type is the value type of the map.
    val valueTypeOf: MapType => DataType = mapType => mapType.valueType
    GetArrayFromMap(child, "valueArray", extractValues, valueTypeOf)
  }
}
|
||||
|
||||
/**
 * Expression that pulls an array (of keys or of values) out of a map-typed child expression.
 *
 * @param child a Map expression to extract an array from
 * @param functionName name of the no-argument method invoked on the map value in generated code
 * @param arrayGetter function used in interpreted evaluation to get `ArrayData` from `MapData`
 * @param elementTypeGetter function deriving the array element `DataType` from the `MapType`
 */
case class GetArrayFromMap private(
    child: Expression,
    functionName: String,
    arrayGetter: MapData => ArrayData,
    elementTypeGetter: MapType => DataType) extends UnaryExpression with NonSQLExpression {

  // Encoded form of the method name, as spliced into the generated code.
  private lazy val encodedFunctionName: String = TermName(functionName).encodedName.toString

  // Array type whose element type is derived from the child's map type.
  lazy val dataType: DataType =
    ArrayType(elementTypeGetter(child.dataType.asInstanceOf[MapType]))

  // Only a map-typed child is accepted; anything else is reported as a type check failure.
  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case _: MapType =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(
        s"Can't extract array from $child: need map type but got ${child.dataType.catalogString}")
  }

  override def nullSafeEval(input: Any): Any = arrayGetter(input.asInstanceOf[MapData])

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, mapValue => s"$mapValue.$encodedFunctionName()")

  override def toString: String = s"$child.$functionName"
}
|
||||
|
|
|
@ -0,0 +1,240 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test.org.apache.spark.sql;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.*;
|
||||
|
||||
import org.junit.*;
|
||||
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Encoder;
|
||||
import org.apache.spark.sql.Encoders;
|
||||
import org.apache.spark.sql.test.TestSparkSession;
|
||||
|
||||
public class JavaBeanDeserializationSuite implements Serializable {
|
||||
|
||||
private TestSparkSession spark;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
spark = new TestSparkSession();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
spark.stop();
|
||||
spark = null;
|
||||
}
|
||||
|
||||
private static final List<ArrayRecord> ARRAY_RECORDS = new ArrayList<>();
|
||||
|
||||
static {
|
||||
ARRAY_RECORDS.add(
|
||||
new ArrayRecord(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
|
||||
);
|
||||
ARRAY_RECORDS.add(
|
||||
new ArrayRecord(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
|
||||
);
|
||||
ARRAY_RECORDS.add(
|
||||
new ArrayRecord(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBeanWithArrayFieldDeserialization() {
|
||||
|
||||
Encoder<ArrayRecord> encoder = Encoders.bean(ArrayRecord.class);
|
||||
|
||||
Dataset<ArrayRecord> dataset = spark
|
||||
.read()
|
||||
.format("json")
|
||||
.schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>")
|
||||
.load("src/test/resources/test-data/with-array-fields.json")
|
||||
.as(encoder);
|
||||
|
||||
List<ArrayRecord> records = dataset.collectAsList();
|
||||
Assert.assertEquals(records, ARRAY_RECORDS);
|
||||
}
|
||||
|
||||
private static final List<MapRecord> MAP_RECORDS = new ArrayList<>();
|
||||
|
||||
static {
|
||||
MAP_RECORDS.add(new MapRecord(1,
|
||||
toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(111, 211), new Interval(121, 221)))
|
||||
));
|
||||
MAP_RECORDS.add(new MapRecord(2,
|
||||
toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(112, 212), new Interval(122, 222)))
|
||||
));
|
||||
MAP_RECORDS.add(new MapRecord(3,
|
||||
toMap(Arrays.asList("a", "b"), Arrays.asList(new Interval(113, 213), new Interval(123, 223)))
|
||||
));
|
||||
MAP_RECORDS.add(new MapRecord(4, new HashMap<>()));
|
||||
MAP_RECORDS.add(new MapRecord(5, null));
|
||||
}
|
||||
|
||||
private static <K, V> Map<K, V> toMap(Collection<K> keys, Collection<V> values) {
|
||||
Map<K, V> map = new HashMap<>();
|
||||
Iterator<K> keyI = keys.iterator();
|
||||
Iterator<V> valueI = values.iterator();
|
||||
while (keyI.hasNext() && valueI.hasNext()) {
|
||||
map.put(keyI.next(), valueI.next());
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBeanWithMapFieldsDeserialization() {
|
||||
|
||||
Encoder<MapRecord> encoder = Encoders.bean(MapRecord.class);
|
||||
|
||||
Dataset<MapRecord> dataset = spark
|
||||
.read()
|
||||
.format("json")
|
||||
.schema("id int, intervals map<string, struct<startTime: bigint, endTime: bigint>>")
|
||||
.load("src/test/resources/test-data/with-map-fields.json")
|
||||
.as(encoder);
|
||||
|
||||
List<MapRecord> records = dataset.collectAsList();
|
||||
|
||||
Assert.assertEquals(records, MAP_RECORDS);
|
||||
}
|
||||
|
||||
public static class ArrayRecord {
|
||||
|
||||
private int id;
|
||||
private List<Interval> intervals;
|
||||
|
||||
public ArrayRecord() { }
|
||||
|
||||
ArrayRecord(int id, List<Interval> intervals) {
|
||||
this.id = id;
|
||||
this.intervals = intervals;
|
||||
}
|
||||
|
||||
public int getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(int id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public List<Interval> getIntervals() {
|
||||
return intervals;
|
||||
}
|
||||
|
||||
public void setIntervals(List<Interval> intervals) {
|
||||
this.intervals = intervals;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (!(obj instanceof ArrayRecord)) return false;
|
||||
ArrayRecord other = (ArrayRecord) obj;
|
||||
return (other.id == this.id) && other.intervals.equals(this.intervals);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("{ id: %d, intervals: %s }", id, intervals);
|
||||
}
|
||||
}
|
||||
|
||||
public static class MapRecord {
|
||||
|
||||
private int id;
|
||||
private Map<String, Interval> intervals;
|
||||
|
||||
public MapRecord() { }
|
||||
|
||||
MapRecord(int id, Map<String, Interval> intervals) {
|
||||
this.id = id;
|
||||
this.intervals = intervals;
|
||||
}
|
||||
|
||||
public int getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(int id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public Map<String, Interval> getIntervals() {
|
||||
return intervals;
|
||||
}
|
||||
|
||||
public void setIntervals(Map<String, Interval> intervals) {
|
||||
this.intervals = intervals;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (!(obj instanceof MapRecord)) return false;
|
||||
MapRecord other = (MapRecord) obj;
|
||||
return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("{ id: %d, intervals: %s }", id, intervals);
|
||||
}
|
||||
}
|
||||
|
||||
public static class Interval {
|
||||
|
||||
private long startTime;
|
||||
private long endTime;
|
||||
|
||||
public Interval() { }
|
||||
|
||||
Interval(long startTime, long endTime) {
|
||||
this.startTime = startTime;
|
||||
this.endTime = endTime;
|
||||
}
|
||||
|
||||
public long getStartTime() {
|
||||
return startTime;
|
||||
}
|
||||
|
||||
public void setStartTime(long startTime) {
|
||||
this.startTime = startTime;
|
||||
}
|
||||
|
||||
public long getEndTime() {
|
||||
return endTime;
|
||||
}
|
||||
|
||||
public void setEndTime(long endTime) {
|
||||
this.endTime = endTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (!(obj instanceof Interval)) return false;
|
||||
Interval other = (Interval) obj;
|
||||
return (other.startTime == this.startTime) && (other.endTime == this.endTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("[%d,%d]", startTime, endTime);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,154 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test.org.apache.spark.sql;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Encoder;
|
||||
import org.apache.spark.sql.Encoders;
|
||||
import org.apache.spark.sql.test.TestSparkSession;
|
||||
|
||||
public class JavaBeanWithArraySuite {
|
||||
|
||||
private static final List<Record> RECORDS = new ArrayList<>();
|
||||
|
||||
static {
|
||||
RECORDS.add(new Record(1, Arrays.asList(new Interval(111, 211), new Interval(121, 221))));
|
||||
RECORDS.add(new Record(2, Arrays.asList(new Interval(112, 212), new Interval(122, 222))));
|
||||
RECORDS.add(new Record(3, Arrays.asList(new Interval(113, 213), new Interval(123, 223))));
|
||||
}
|
||||
|
||||
private TestSparkSession spark;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
spark = new TestSparkSession();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
spark.stop();
|
||||
spark = null;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBeanWithArrayFieldDeserialization() {
|
||||
|
||||
Encoder<Record> encoder = Encoders.bean(Record.class);
|
||||
|
||||
Dataset<Record> dataset = spark
|
||||
.read()
|
||||
.format("json")
|
||||
.schema("id int, intervals array<struct<startTime: bigint, endTime: bigint>>")
|
||||
.load("src/test/resources/test-data/with-array-fields.json")
|
||||
.as(encoder);
|
||||
|
||||
List<Record> records = dataset.collectAsList();
|
||||
Assert.assertEquals(records, RECORDS);
|
||||
}
|
||||
|
||||
public static class Record {
|
||||
|
||||
private int id;
|
||||
private List<Interval> intervals;
|
||||
|
||||
public Record() { }
|
||||
|
||||
Record(int id, List<Interval> intervals) {
|
||||
this.id = id;
|
||||
this.intervals = intervals;
|
||||
}
|
||||
|
||||
public int getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(int id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public List<Interval> getIntervals() {
|
||||
return intervals;
|
||||
}
|
||||
|
||||
public void setIntervals(List<Interval> intervals) {
|
||||
this.intervals = intervals;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (!(obj instanceof Record)) return false;
|
||||
Record other = (Record) obj;
|
||||
return (other.id == this.id) && other.intervals.equals(this.intervals);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("{ id: %d, intervals: %s }", id, intervals);
|
||||
}
|
||||
}
|
||||
|
||||
public static class Interval {
|
||||
|
||||
private long startTime;
|
||||
private long endTime;
|
||||
|
||||
public Interval() { }
|
||||
|
||||
Interval(long startTime, long endTime) {
|
||||
this.startTime = startTime;
|
||||
this.endTime = endTime;
|
||||
}
|
||||
|
||||
public long getStartTime() {
|
||||
return startTime;
|
||||
}
|
||||
|
||||
public void setStartTime(long startTime) {
|
||||
this.startTime = startTime;
|
||||
}
|
||||
|
||||
public long getEndTime() {
|
||||
return endTime;
|
||||
}
|
||||
|
||||
public void setEndTime(long endTime) {
|
||||
this.endTime = endTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (!(obj instanceof Interval)) return false;
|
||||
Interval other = (Interval) obj;
|
||||
return (other.startTime == this.startTime) && (other.endTime == this.endTime);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("[%d,%d]", startTime, endTime);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
{ "id": 1, "intervals": { "a": { "startTime": 111, "endTime": 211 }, "b": { "startTime": 121, "endTime": 221 }}}
|
||||
{ "id": 2, "intervals": { "a": { "startTime": 112, "endTime": 212 }, "b": { "startTime": 122, "endTime": 222 }}}
|
||||
{ "id": 3, "intervals": { "a": { "startTime": 113, "endTime": 213 }, "b": { "startTime": 123, "endTime": 223 }}}
|
||||
{ "id": 4, "intervals": { }}
|
||||
{ "id": 5 }
|
Loading…
Reference in a new issue