From 895218996abeb94ab3fed1ebc9215fa00c02227b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 17 Sep 2021 21:37:19 +0800
Subject: [PATCH] [SPARK-36673][SQL] Fix incorrect schema of nested types of
 union

### What changes were proposed in this pull request?

This patch fixes the incorrect schema reported by `union`.

### Why are the changes needed?

The current `union` result schema for nested struct columns is incorrect. By definition of the `union` API, columns should be resolved by position, not by name. But right now, when determining the `output` (i.e. the schema) of a union plan, we use the `merge` API, which merges two structs by name (roughly, it concatenates the fields of the two structs where they don't overlap). That merging behavior doesn't match the `union` definition.

So currently we get an incorrect schema, even though the query result itself is correct. We should fix the incorrect schema.
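As a minimal sketch of the problem, mirroring the test added in `DataFrameSetOperationsSuite` (assuming a `spark-shell` session with `spark` in scope):

```scala
import org.apache.spark.sql.functions.{expr, struct}

// Union-compatible by position; the nested field names differ only in case.
val df1 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
val df2 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))

// `union` resolves columns by position, so the result schema should keep the
// left side's struct as-is: struct<INNER: bigint>.
// Before this fix, the schema was computed with the by-name `merge`, which
// concatenated the non-overlapping fields into something like
// struct<INNER: bigint, inner: bigint>, while each row still carried one value.
df1.union(df2).printSchema()
```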
### Does this PR introduce _any_ user-facing change?

Yes, fixing a bug of incorrect schema.

### How was this patch tested?

Added unit test.

Closes #34025 from viirya/SPARK-36673.

Authored-by: Liang-Chi Hsieh
Signed-off-by: Wenchen Fan
(cherry picked from commit cdd7ae937de635bd0dc38e33a8ceafbbf159a75b)
Signed-off-by: Wenchen Fan
---
 .../plans/logical/basicLogicalOperators.scala |  2 +-
 .../apache/spark/sql/types/StructType.scala   | 97 ++++++++++++-------
 .../execution/basicPhysicalOperators.scala    |  2 +-
 .../sql/DataFrameSetOperationsSuite.scala     | 58 +++++++++++
 4 files changed, 123 insertions(+), 36 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index e456c5d053..50e8d64feb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -307,7 +307,7 @@ case class Union(
     children.map(_.output).transpose.map { attrs =>
       val firstAttr = attrs.head
       val nullable = attrs.exists(_.nullable)
-      val newDt = attrs.map(_.dataType).reduce(StructType.merge)
+      val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
       if (firstAttr.dataType == newDt) {
         firstAttr.withNullability(nullable)
       } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 87ff4eb571..83ee1913da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -559,52 +559,81 @@ object StructType extends AbstractDataType {
     case _ => dt
   }
 
+  /**
+   * This leverages `merge` to merge data types for UNION operator by specializing
+   * the handling of struct types to follow UNION semantics.
+   */
+  private[sql] def unionLikeMerge(left: DataType, right: DataType): DataType =
+    mergeInternal(left, right, (s1: StructType, s2: StructType) => {
+      val leftFields = s1.fields
+      val rightFields = s2.fields
+      require(leftFields.size == rightFields.size, "To merge nullability, " +
+        "two structs must have same number of fields.")
+
+      val newFields = leftFields.zip(rightFields).map {
+        case (leftField, rightField) =>
+          leftField.copy(
+            dataType = unionLikeMerge(leftField.dataType, rightField.dataType),
+            nullable = leftField.nullable || rightField.nullable)
+      }.toSeq
+      StructType(newFields)
+    })
+
   private[sql] def merge(left: DataType, right: DataType): DataType =
+    mergeInternal(left, right, (s1: StructType, s2: StructType) => {
+      val leftFields = s1.fields
+      val rightFields = s2.fields
+      val newFields = mutable.ArrayBuffer.empty[StructField]
+
+      val rightMapped = fieldsMap(rightFields)
+      leftFields.foreach {
+        case leftField @ StructField(leftName, leftType, leftNullable, _) =>
+          rightMapped.get(leftName)
+            .map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
+              try {
+                leftField.copy(
+                  dataType = merge(leftType, rightType),
+                  nullable = leftNullable || rightNullable)
+              } catch {
+                case NonFatal(e) =>
+                  throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
+              }
+            }
+            .orElse {
+              Some(leftField)
+            }
+            .foreach(newFields += _)
+      }
+
+      val leftMapped = fieldsMap(leftFields)
+      rightFields
+        .filterNot(f => leftMapped.get(f.name).nonEmpty)
+        .foreach { f =>
+          newFields += f
+        }
+
+      StructType(newFields.toSeq)
+    })
+
+  private def mergeInternal(
+      left: DataType,
+      right: DataType,
+      mergeStruct: (StructType, StructType) => StructType): DataType =
     (left, right) match {
       case (ArrayType(leftElementType, leftContainsNull),
         ArrayType(rightElementType, rightContainsNull)) =>
         ArrayType(
-          merge(leftElementType, rightElementType),
+          mergeInternal(leftElementType, rightElementType, mergeStruct),
           leftContainsNull || rightContainsNull)
 
       case (MapType(leftKeyType, leftValueType, leftContainsNull),
         MapType(rightKeyType, rightValueType, rightContainsNull)) =>
         MapType(
-          merge(leftKeyType, rightKeyType),
-          merge(leftValueType, rightValueType),
+          mergeInternal(leftKeyType, rightKeyType, mergeStruct),
+          mergeInternal(leftValueType, rightValueType, mergeStruct),
           leftContainsNull || rightContainsNull)
 
-      case (StructType(leftFields), StructType(rightFields)) =>
-        val newFields = mutable.ArrayBuffer.empty[StructField]
-
-        val rightMapped = fieldsMap(rightFields)
-        leftFields.foreach {
-          case leftField @ StructField(leftName, leftType, leftNullable, _) =>
-            rightMapped.get(leftName)
-              .map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
-                try {
-                  leftField.copy(
-                    dataType = merge(leftType, rightType),
-                    nullable = leftNullable || rightNullable)
-                } catch {
-                  case NonFatal(e) =>
-                    throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
-                }
-              }
-              .orElse {
-                Some(leftField)
-              }
-              .foreach(newFields += _)
-        }
-
-        val leftMapped = fieldsMap(leftFields)
-        rightFields
-          .filterNot(f => leftMapped.get(f.name).nonEmpty)
-          .foreach { f =>
-            newFields += f
-          }
-
-        StructType(newFields.toSeq)
+      case (s1: StructType, s2: StructType) => mergeStruct(s1, s2)
 
       case (DecimalType.Fixed(leftPrecision, leftScale),
         DecimalType.Fixed(rightPrecision, rightScale)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 7bd4dc7be1..8e0080a246 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -684,7 +684,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
     children.map(_.output).transpose.map { attrs =>
       val firstAttr = attrs.head
       val nullable = attrs.exists(_.nullable)
-      val newDt = attrs.map(_.dataType).reduce(StructType.merge)
+      val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
       if (firstAttr.dataType == newDt) {
         firstAttr.withNullability(nullable)
       } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index e3259a2460..bd2e91ba94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -1018,6 +1018,64 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
     unionDF = df1.unionByName(df2)
     checkAnswer(unionDF, expected)
   }
+
+  test("SPARK-36673: Only merge nullability for Unions of struct") {
+    val df1 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
+    val df2 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))
+
+    val union1 = df1.union(df2)
+    val union2 = df1.unionByName(df2)
+
+    val schema = StructType(Seq(StructField("id", LongType, false),
+      StructField("nested", StructType(Seq(StructField("INNER", LongType, false))), false)))
+
+    Seq(union1, union2).foreach { df =>
+      assert(df.schema == schema)
+      assert(df.queryExecution.optimizedPlan.schema == schema)
+      assert(df.queryExecution.executedPlan.schema == schema)
+
+      checkAnswer(df, Row(0, Row(0)) :: Row(1, Row(5)) :: Row(0, Row(0)) :: Row(1, Row(5)) :: Nil)
+      checkAnswer(df.select("nested.*"), Row(0) :: Row(5) :: Row(0) :: Row(5) :: Nil)
+    }
+  }
+
+  test("SPARK-36673: Only merge nullability for unionByName of struct") {
+    val df1 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
+    val df2 = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))
+
+    val df = df1.unionByName(df2)
+
+    val schema = StructType(Seq(StructField("id", LongType, false),
+      StructField("nested", StructType(Seq(StructField("INNER", LongType, false))), false)))
+
+    assert(df.schema == schema)
+    assert(df.queryExecution.optimizedPlan.schema == schema)
+    assert(df.queryExecution.executedPlan.schema == schema)
+
+    checkAnswer(df, Row(0, Row(0)) :: Row(1, Row(5)) :: Row(0, Row(0)) :: Row(1, Row(5)) :: Nil)
+    checkAnswer(df.select("nested.*"), Row(0) :: Row(5) :: Row(0) :: Row(5) :: Nil)
+  }
+
+  test("SPARK-36673: Union of structs with different orders") {
+    val df1 = spark.range(2).withColumn("nested",
+      struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2"))))
+    val df2 = spark.range(2).withColumn("nested",
+      struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1"))))
+
+    val err1 = intercept[AnalysisException](df1.union(df2).collect())
+
+    assert(err1.message
+      .contains("Union can only be performed on tables with the compatible column types"))
+
+    val df3 = spark.range(2).withColumn("nested",
+      struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string"))))
+    val df4 = spark.range(2).withColumn("nested",
inner2").cast("string"), struct(expr("id * 10 AS inner1")))) + + val err2 = intercept[AnalysisException](df3.union(df4).collect()) + assert(err2.message + .contains("Union can only be performed on tables with the compatible column types")) + } } case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)