[SPARK-36673][SQL] Fix incorrect schema of nested types of union
### What changes were proposed in this pull request?
This patch proposes to fix incorrect schema of `union`.
### Why are the changes needed?
The current `union` result of nested struct columns is incorrect. By definition of the `union` API, it should resolve columns by position, not by name. Right now, when determining the `output` (i.e., the schema) of a union plan, we use the `merge` API, which actually merges two structs (think of it as concatenating the fields of the two structs where they do not overlap). This merging behavior doesn't match the `union` definition.
So currently we get incorrect schema but the query result is correct. We should fix the incorrect schema.
### Does this PR introduce _any_ user-facing change?
Yes, fixing a bug of incorrect schema.
### How was this patch tested?
Added unit test.
Closes #34025 from viirya/SPARK-36673.
Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit cdd7ae937d)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
af7dd18a5e
commit
895218996a
|
@ -307,7 +307,7 @@ case class Union(
|
||||||
children.map(_.output).transpose.map { attrs =>
|
children.map(_.output).transpose.map { attrs =>
|
||||||
val firstAttr = attrs.head
|
val firstAttr = attrs.head
|
||||||
val nullable = attrs.exists(_.nullable)
|
val nullable = attrs.exists(_.nullable)
|
||||||
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
|
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
|
||||||
if (firstAttr.dataType == newDt) {
|
if (firstAttr.dataType == newDt) {
|
||||||
firstAttr.withNullability(nullable)
|
firstAttr.withNullability(nullable)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -559,52 +559,81 @@ object StructType extends AbstractDataType {
|
||||||
case _ => dt
|
case _ => dt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This leverages `merge` to merge data types for UNION operator by specializing
|
||||||
|
* the handling of struct types to follow UNION semantics.
|
||||||
|
*/
|
||||||
|
private[sql] def unionLikeMerge(left: DataType, right: DataType): DataType =
|
||||||
|
mergeInternal(left, right, (s1: StructType, s2: StructType) => {
|
||||||
|
val leftFields = s1.fields
|
||||||
|
val rightFields = s2.fields
|
||||||
|
require(leftFields.size == rightFields.size, "To merge nullability, " +
|
||||||
|
"two structs must have same number of fields.")
|
||||||
|
|
||||||
|
val newFields = leftFields.zip(rightFields).map {
|
||||||
|
case (leftField, rightField) =>
|
||||||
|
leftField.copy(
|
||||||
|
dataType = unionLikeMerge(leftField.dataType, rightField.dataType),
|
||||||
|
nullable = leftField.nullable || rightField.nullable)
|
||||||
|
}.toSeq
|
||||||
|
StructType(newFields)
|
||||||
|
})
|
||||||
|
|
||||||
private[sql] def merge(left: DataType, right: DataType): DataType =
|
private[sql] def merge(left: DataType, right: DataType): DataType =
|
||||||
|
mergeInternal(left, right, (s1: StructType, s2: StructType) => {
|
||||||
|
val leftFields = s1.fields
|
||||||
|
val rightFields = s2.fields
|
||||||
|
val newFields = mutable.ArrayBuffer.empty[StructField]
|
||||||
|
|
||||||
|
val rightMapped = fieldsMap(rightFields)
|
||||||
|
leftFields.foreach {
|
||||||
|
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
|
||||||
|
rightMapped.get(leftName)
|
||||||
|
.map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
|
||||||
|
try {
|
||||||
|
leftField.copy(
|
||||||
|
dataType = merge(leftType, rightType),
|
||||||
|
nullable = leftNullable || rightNullable)
|
||||||
|
} catch {
|
||||||
|
case NonFatal(e) =>
|
||||||
|
throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.orElse {
|
||||||
|
Some(leftField)
|
||||||
|
}
|
||||||
|
.foreach(newFields += _)
|
||||||
|
}
|
||||||
|
|
||||||
|
val leftMapped = fieldsMap(leftFields)
|
||||||
|
rightFields
|
||||||
|
.filterNot(f => leftMapped.get(f.name).nonEmpty)
|
||||||
|
.foreach { f =>
|
||||||
|
newFields += f
|
||||||
|
}
|
||||||
|
|
||||||
|
StructType(newFields.toSeq)
|
||||||
|
})
|
||||||
|
|
||||||
|
private def mergeInternal(
|
||||||
|
left: DataType,
|
||||||
|
right: DataType,
|
||||||
|
mergeStruct: (StructType, StructType) => StructType): DataType =
|
||||||
(left, right) match {
|
(left, right) match {
|
||||||
case (ArrayType(leftElementType, leftContainsNull),
|
case (ArrayType(leftElementType, leftContainsNull),
|
||||||
ArrayType(rightElementType, rightContainsNull)) =>
|
ArrayType(rightElementType, rightContainsNull)) =>
|
||||||
ArrayType(
|
ArrayType(
|
||||||
merge(leftElementType, rightElementType),
|
mergeInternal(leftElementType, rightElementType, mergeStruct),
|
||||||
leftContainsNull || rightContainsNull)
|
leftContainsNull || rightContainsNull)
|
||||||
|
|
||||||
case (MapType(leftKeyType, leftValueType, leftContainsNull),
|
case (MapType(leftKeyType, leftValueType, leftContainsNull),
|
||||||
MapType(rightKeyType, rightValueType, rightContainsNull)) =>
|
MapType(rightKeyType, rightValueType, rightContainsNull)) =>
|
||||||
MapType(
|
MapType(
|
||||||
merge(leftKeyType, rightKeyType),
|
mergeInternal(leftKeyType, rightKeyType, mergeStruct),
|
||||||
merge(leftValueType, rightValueType),
|
mergeInternal(leftValueType, rightValueType, mergeStruct),
|
||||||
leftContainsNull || rightContainsNull)
|
leftContainsNull || rightContainsNull)
|
||||||
|
|
||||||
case (StructType(leftFields), StructType(rightFields)) =>
|
case (s1: StructType, s2: StructType) => mergeStruct(s1, s2)
|
||||||
val newFields = mutable.ArrayBuffer.empty[StructField]
|
|
||||||
|
|
||||||
val rightMapped = fieldsMap(rightFields)
|
|
||||||
leftFields.foreach {
|
|
||||||
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
|
|
||||||
rightMapped.get(leftName)
|
|
||||||
.map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
|
|
||||||
try {
|
|
||||||
leftField.copy(
|
|
||||||
dataType = merge(leftType, rightType),
|
|
||||||
nullable = leftNullable || rightNullable)
|
|
||||||
} catch {
|
|
||||||
case NonFatal(e) =>
|
|
||||||
throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.orElse {
|
|
||||||
Some(leftField)
|
|
||||||
}
|
|
||||||
.foreach(newFields += _)
|
|
||||||
}
|
|
||||||
|
|
||||||
val leftMapped = fieldsMap(leftFields)
|
|
||||||
rightFields
|
|
||||||
.filterNot(f => leftMapped.get(f.name).nonEmpty)
|
|
||||||
.foreach { f =>
|
|
||||||
newFields += f
|
|
||||||
}
|
|
||||||
|
|
||||||
StructType(newFields.toSeq)
|
|
||||||
|
|
||||||
case (DecimalType.Fixed(leftPrecision, leftScale),
|
case (DecimalType.Fixed(leftPrecision, leftScale),
|
||||||
DecimalType.Fixed(rightPrecision, rightScale)) =>
|
DecimalType.Fixed(rightPrecision, rightScale)) =>
|
||||||
|
|
|
@ -684,7 +684,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
|
||||||
children.map(_.output).transpose.map { attrs =>
|
children.map(_.output).transpose.map { attrs =>
|
||||||
val firstAttr = attrs.head
|
val firstAttr = attrs.head
|
||||||
val nullable = attrs.exists(_.nullable)
|
val nullable = attrs.exists(_.nullable)
|
||||||
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
|
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
|
||||||
if (firstAttr.dataType == newDt) {
|
if (firstAttr.dataType == newDt) {
|
||||||
firstAttr.withNullability(nullable)
|
firstAttr.withNullability(nullable)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1018,6 +1018,64 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
|
||||||
unionDF = df1.unionByName(df2)
|
unionDF = df1.unionByName(df2)
|
||||||
checkAnswer(unionDF, expected)
|
checkAnswer(unionDF, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-36673: Only merge nullability for Unions of struct") {
|
||||||
|
val df1 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
|
||||||
|
val df2 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))
|
||||||
|
|
||||||
|
val union1 = df1.union(df2)
|
||||||
|
val union2 = df1.unionByName(df2)
|
||||||
|
|
||||||
|
val schema = StructType(Seq(StructField("id", LongType, false),
|
||||||
|
StructField("nested", StructType(Seq(StructField("INNER", LongType, false))), false)))
|
||||||
|
|
||||||
|
Seq(union1, union2).foreach { df =>
|
||||||
|
assert(df.schema == schema)
|
||||||
|
assert(df.queryExecution.optimizedPlan.schema == schema)
|
||||||
|
assert(df.queryExecution.executedPlan.schema == schema)
|
||||||
|
|
||||||
|
checkAnswer(df, Row(0, Row(0)) :: Row(1, Row(5)) :: Row(0, Row(0)) :: Row(1, Row(5)) :: Nil)
|
||||||
|
checkAnswer(df.select("nested.*"), Row(0) :: Row(5) :: Row(0) :: Row(5) :: Nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-36673: Only merge nullability for unionByName of struct") {
|
||||||
|
val df1 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
|
||||||
|
val df2 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))
|
||||||
|
|
||||||
|
val df = df1.unionByName(df2)
|
||||||
|
|
||||||
|
val schema = StructType(Seq(StructField("id", LongType, false),
|
||||||
|
StructField("nested", StructType(Seq(StructField("INNER", LongType, false))), false)))
|
||||||
|
|
||||||
|
assert(df.schema == schema)
|
||||||
|
assert(df.queryExecution.optimizedPlan.schema == schema)
|
||||||
|
assert(df.queryExecution.executedPlan.schema == schema)
|
||||||
|
|
||||||
|
checkAnswer(df, Row(0, Row(0)) :: Row(1, Row(5)) :: Row(0, Row(0)) :: Row(1, Row(5)) :: Nil)
|
||||||
|
checkAnswer(df.select("nested.*"), Row(0) :: Row(5) :: Row(0) :: Row(5) :: Nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-36673: Union of structs with different orders") {
|
||||||
|
val df1 = spark.range(2).withColumn("nested",
|
||||||
|
struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2"))))
|
||||||
|
val df2 = spark.range(2).withColumn("nested",
|
||||||
|
struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1"))))
|
||||||
|
|
||||||
|
val err1 = intercept[AnalysisException](df1.union(df2).collect())
|
||||||
|
|
||||||
|
assert(err1.message
|
||||||
|
.contains("Union can only be performed on tables with the compatible column types"))
|
||||||
|
|
||||||
|
val df3 = spark.range(2).withColumn("nested",
|
||||||
|
struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string"))))
|
||||||
|
val df4 = spark.range(2).withColumn("nested",
|
||||||
|
struct(expr("id * 5 AS inner2").cast("string"), struct(expr("id * 10 AS inner1"))))
|
||||||
|
|
||||||
|
val err2 = intercept[AnalysisException](df3.union(df4).collect())
|
||||||
|
assert(err2.message
|
||||||
|
.contains("Union can only be performed on tables with the compatible column types"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
|
case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
|
||||||
|
|
Loading…
Reference in a new issue