[SPARK-36673][SQL] Fix incorrect schema of nested types of union
### What changes were proposed in this pull request?
This patch proposes to fix incorrect schema of `union`.
### Why are the changes needed?
The current `union` result of nested struct columns is incorrect. By definition of the `union` API, it should resolve columns by position, not by name. Right now, when determining the `output` (i.e. the schema) of a union plan, we use the `merge` API, which actually merges two structs (think of it as concatenating the fields of two structs where they do not overlap). This merging behavior doesn't match the `union` definition.
So currently we get incorrect schema but the query result is correct. We should fix the incorrect schema.
### Does this PR introduce _any_ user-facing change?
Yes, fixing a bug of incorrect schema.
### How was this patch tested?
Added unit test.
Closes #34025 from viirya/SPARK-36673.
Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit cdd7ae937d)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
af7dd18a5e
commit
895218996a
|
@ -307,7 +307,7 @@ case class Union(
|
|||
children.map(_.output).transpose.map { attrs =>
|
||||
val firstAttr = attrs.head
|
||||
val nullable = attrs.exists(_.nullable)
|
||||
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
|
||||
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
|
||||
if (firstAttr.dataType == newDt) {
|
||||
firstAttr.withNullability(nullable)
|
||||
} else {
|
||||
|
|
|
@ -559,52 +559,81 @@ object StructType extends AbstractDataType {
|
|||
case _ => dt
|
||||
}
|
||||
|
||||
/**
 * Merges two data types for the UNION operator. Non-struct types are handled
 * by `mergeInternal`; structs are combined positionally (the i-th field on the
 * left with the i-th field on the right), keeping the left side's field names
 * and widening nullability, which matches UNION's resolve-by-position semantics.
 */
private[sql] def unionLikeMerge(left: DataType, right: DataType): DataType =
  mergeInternal(left, right, (s1: StructType, s2: StructType) => {
    require(s1.fields.length == s2.fields.length,
      "To merge nullability, two structs must have same number of fields.")
    // Pair fields strictly by position; the left field's name wins, and a field
    // is nullable if it is nullable on either side.
    val mergedFields = s1.fields.zip(s2.fields).map { case (lf, rf) =>
      lf.copy(
        dataType = unionLikeMerge(lf.dataType, rf.dataType),
        nullable = lf.nullable || rf.nullable)
    }
    StructType(mergedFields.toSeq)
  })
|
||||
|
||||
/**
 * Merges two data types by name (schema-merge semantics, e.g. for reading data
 * sources with evolving schemas). For structs: fields present on both sides are
 * merged recursively with widened nullability, left-only fields are kept as-is,
 * and right-only fields are appended after all left fields, preserving order.
 */
private[sql] def merge(left: DataType, right: DataType): DataType =
  mergeInternal(left, right, (s1: StructType, s2: StructType) => {
    val rightByName = fieldsMap(s2.fields)
    // Walk the left fields in order, merging with the same-named right field
    // when one exists; otherwise keep the left field unchanged.
    val mergedLeft = s1.fields.map {
      case leftField @ StructField(leftName, leftType, leftNullable, _) =>
        rightByName.get(leftName) match {
          case Some(StructField(rightName, rightType, rightNullable, _)) =>
            try {
              leftField.copy(
                dataType = merge(leftType, rightType),
                nullable = leftNullable || rightNullable)
            } catch {
              // NonFatal only: wrap the failure with the field names involved.
              case NonFatal(e) =>
                throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
            }
          case None => leftField
        }
    }
    // Fields that exist only on the right are appended after all left fields.
    val leftByName = fieldsMap(s1.fields)
    val rightOnly = s2.fields.filter(f => leftByName.get(f.name).isEmpty)
    StructType((mergedLeft ++ rightOnly).toSeq)
  })
|
||||
|
||||
private def mergeInternal(
|
||||
left: DataType,
|
||||
right: DataType,
|
||||
mergeStruct: (StructType, StructType) => StructType): DataType =
|
||||
(left, right) match {
|
||||
case (ArrayType(leftElementType, leftContainsNull),
|
||||
ArrayType(rightElementType, rightContainsNull)) =>
|
||||
ArrayType(
|
||||
merge(leftElementType, rightElementType),
|
||||
mergeInternal(leftElementType, rightElementType, mergeStruct),
|
||||
leftContainsNull || rightContainsNull)
|
||||
|
||||
case (MapType(leftKeyType, leftValueType, leftContainsNull),
|
||||
MapType(rightKeyType, rightValueType, rightContainsNull)) =>
|
||||
MapType(
|
||||
merge(leftKeyType, rightKeyType),
|
||||
merge(leftValueType, rightValueType),
|
||||
mergeInternal(leftKeyType, rightKeyType, mergeStruct),
|
||||
mergeInternal(leftValueType, rightValueType, mergeStruct),
|
||||
leftContainsNull || rightContainsNull)
|
||||
|
||||
case (StructType(leftFields), StructType(rightFields)) =>
|
||||
val newFields = mutable.ArrayBuffer.empty[StructField]
|
||||
|
||||
val rightMapped = fieldsMap(rightFields)
|
||||
leftFields.foreach {
|
||||
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
|
||||
rightMapped.get(leftName)
|
||||
.map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
|
||||
try {
|
||||
leftField.copy(
|
||||
dataType = merge(leftType, rightType),
|
||||
nullable = leftNullable || rightNullable)
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
|
||||
}
|
||||
}
|
||||
.orElse {
|
||||
Some(leftField)
|
||||
}
|
||||
.foreach(newFields += _)
|
||||
}
|
||||
|
||||
val leftMapped = fieldsMap(leftFields)
|
||||
rightFields
|
||||
.filterNot(f => leftMapped.get(f.name).nonEmpty)
|
||||
.foreach { f =>
|
||||
newFields += f
|
||||
}
|
||||
|
||||
StructType(newFields.toSeq)
|
||||
case (s1: StructType, s2: StructType) => mergeStruct(s1, s2)
|
||||
|
||||
case (DecimalType.Fixed(leftPrecision, leftScale),
|
||||
DecimalType.Fixed(rightPrecision, rightScale)) =>
|
||||
|
|
|
@ -684,7 +684,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
|
|||
children.map(_.output).transpose.map { attrs =>
|
||||
val firstAttr = attrs.head
|
||||
val nullable = attrs.exists(_.nullable)
|
||||
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
|
||||
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
|
||||
if (firstAttr.dataType == newDt) {
|
||||
firstAttr.withNullability(nullable)
|
||||
} else {
|
||||
|
|
|
@ -1018,6 +1018,64 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
|
|||
unionDF = df1.unionByName(df2)
|
||||
checkAnswer(unionDF, expected)
|
||||
}
|
||||
|
||||
test("SPARK-36673: Only merge nullability for Unions of struct") {
  // Two frames whose nested struct fields differ only in the case of the
  // field name; union is by position so the left side's names must win.
  val upper = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
  val lower = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))

  val expectedSchema = StructType(Seq(
    StructField("id", LongType, false),
    StructField("nested", StructType(Seq(StructField("INNER", LongType, false))), false)))

  val expectedRows = Seq(Row(0, Row(0)), Row(1, Row(5)), Row(0, Row(0)), Row(1, Row(5)))

  // Both union flavors must report the left schema at every plan stage.
  for (df <- Seq(upper.union(lower), upper.unionByName(lower))) {
    assert(df.schema == expectedSchema)
    assert(df.queryExecution.optimizedPlan.schema == expectedSchema)
    assert(df.queryExecution.executedPlan.schema == expectedSchema)

    checkAnswer(df, expectedRows)
    checkAnswer(df.select("nested.*"), Seq(Row(0), Row(5), Row(0), Row(5)))
  }
}
|
||||
|
||||
test("SPARK-36673: Only merge nullability for unionByName of struct") {
  // Nested struct fields differ only in field-name case; unionByName resolves
  // top-level columns by name but the left side's nested field names must win.
  val upper = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
  val lower = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))

  val unioned = upper.unionByName(lower)

  val expectedSchema = StructType(Seq(
    StructField("id", LongType, false),
    StructField("nested", StructType(Seq(StructField("INNER", LongType, false))), false)))

  // Schema must be stable across analysis, optimization, and execution.
  assert(unioned.schema == expectedSchema)
  assert(unioned.queryExecution.optimizedPlan.schema == expectedSchema)
  assert(unioned.queryExecution.executedPlan.schema == expectedSchema)

  checkAnswer(unioned, Seq(Row(0, Row(0)), Row(1, Row(5)), Row(0, Row(0)), Row(1, Row(5))))
  checkAnswer(unioned.select("nested.*"), Seq(Row(0), Row(5), Row(0), Row(5)))
}
|
||||
|
||||
test("SPARK-36673: Union of structs with different orders") {
  val expectedError = "Union can only be performed on tables with the compatible column types"

  // Same leaf types but nested fields in swapped positions: since union
  // matches by position, these structs are incompatible and must be rejected.
  val df1 = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2"))))
  val df2 = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1"))))

  val err1 = intercept[AnalysisException](df1.union(df2).collect())
  assert(err1.message.contains(expectedError))

  // Mismatched leaf types at the swapped positions: also rejected.
  val df3 = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string"))))
  val df4 = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner2").cast("string"), struct(expr("id * 10 AS inner1"))))

  val err2 = intercept[AnalysisException](df3.union(df4).collect())
  assert(err2.message.contains(expectedError))
}
|
||||
}
|
||||
|
||||
// Test fixture for set-operation suites: a nested-struct shape with a struct
// member. NOTE(review): `UnionClass2` is declared elsewhere in this file and is
// not visible in this excerpt — confirm its definition before relying on it.
case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
|
||||
|
|
Loading…
Reference in a new issue