diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 085cc29289..0c75eda7a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -106,8 +106,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
   }
 
   protected lazy val (buildKeys, streamedKeys) = {
-    require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
-      "Join keys from two sides should have same types")
+    require(leftKeys.length == rightKeys.length &&
+      leftKeys.map(_.dataType)
+        .zip(rightKeys.map(_.dataType))
+        .forall(types => types._1.sameType(types._2)),
+      "Join keys from two sides should have same length and types")
     buildSide match {
       case BuildLeft => (leftKeys, rightKeys)
       case BuildRight => (rightKeys, leftKeys)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 0b4f43b723..b463a76a74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo, Join, JoinHint, LogicalPlan, Project}
@@ -29,6 +31,7 @@ import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
 
 class DataFrameJoinSuite extends QueryTest
   with SharedSparkSession
@@ -418,4 +421,40 @@ class DataFrameJoinSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-32693: Compare two dataframes with same schema except nullable property") {
+    val schema1 = StructType(
+      StructField("a", IntegerType, false) ::
+        StructField("b", IntegerType, false) ::
+        StructField("c", IntegerType, false) :: Nil)
+    val rowSeq1: List[Row] = List(Row(10, 1, 1), Row(10, 50, 2))
+    val df1 = spark.createDataFrame(rowSeq1.asJava, schema1)
+
+    val schema2 = StructType(
+      StructField("a", IntegerType) ::
+        StructField("b", IntegerType) ::
+        StructField("c", IntegerType) :: Nil)
+    val rowSeq2: List[Row] = List(Row(10, 1, 1))
+    val df2 = spark.createDataFrame(rowSeq2.asJava, schema2)
+
+    checkAnswer(df1.except(df2), Row(10, 50, 2))
+
+    val schema3 = StructType(
+      StructField("a", IntegerType, false) ::
+        StructField("b", IntegerType, false) ::
+        StructField("c", IntegerType, false) ::
+        StructField("d", schema1, false) :: Nil)
+    val rowSeq3: List[Row] = List(Row(10, 1, 1, Row(10, 1, 1)), Row(10, 50, 2, Row(10, 50, 2)))
+    val df3 = spark.createDataFrame(rowSeq3.asJava, schema3)
+
+    val schema4 = StructType(
+      StructField("a", IntegerType) ::
+        StructField("b", IntegerType) ::
+        StructField("c", IntegerType) ::
+        StructField("d", schema2) :: Nil)
+    val rowSeq4: List[Row] = List(Row(10, 1, 1, Row(10, 1, 1)))
+    val df4 = spark.createDataFrame(rowSeq4.asJava, schema4)
+
+    checkAnswer(df3.except(df4), Row(10, 50, 2, Row(10, 50, 2)))
+  }
 }
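
Why a join-key check fires for except() at all: the optimizer rule ReplaceExceptWithAntiJoin rewrites Except into a left-anti join keyed on all output columns, so the join keys inherit each side's nullability. When one DataFrame's schema is non-nullable and the other's is nullable, the two key lists differ only in their nullability flags, and the old == comparison rejected them. The sketch below is a minimal, hypothetical illustration of the comparison semantics, not part of the patch; since sameType is internal to Spark, it uses the public helper DataType.equalsStructurally with ignoreNullability = true, which has similar nullability-ignoring behavior (it additionally ignores struct field names).

import org.apache.spark.sql.types._

// Hypothetical demo, not part of the patch: two key types identical except
// for the nullable flag, mirroring the nested "d" column in the test above.
object NullabilityCheckSketch {
  def main(args: Array[String]): Unit = {
    val leftKeyType: DataType =
      StructType(StructField("a", IntegerType, nullable = false) :: Nil)
    val rightKeyType: DataType =
      StructType(StructField("a", IntegerType, nullable = true) :: Nil)

    // Case-class equality is nullability-sensitive; this is why the old
    // require() in HashJoin rejected the plan.
    println(leftKeyType == rightKeyType) // false

    // Nullability-ignoring comparison, matching what the patched check
    // relies on via sameType.
    println(DataType.equalsStructurally(
      leftKeyType, rightKeyType, ignoreNullability = true)) // true
  }
}

One detail worth noting in the patch itself: the old == on the two mapped lists implicitly compared their lengths, whereas zip silently truncates to the shorter list, so the new code needs the explicit leftKeys.length == rightKeys.length guard to preserve that behavior.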