diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index bbf0ac1dd8..308bb96502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -395,10 +395,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   private def fillMap(values: Seq[(String, Any)]): DataFrame = {
     // Error handling
-    values.foreach { case (colName, replaceValue) =>
+    val attrToValue = AttributeMap(values.map { case (colName, replaceValue) =>
       // Check column name exists
-      df.resolve(colName)
-
+      val attr = df.resolve(colName) match {
+        case a: Attribute => a
+        case _ => throw new UnsupportedOperationException(
+          s"Nested field ${colName} is not supported.")
+      }
       // Check data type
       replaceValue match {
         case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: jl.Boolean | _: String =>
@@ -406,31 +409,29 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
         case _ => throw new IllegalArgumentException(
           s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).")
       }
-    }
+      attr -> replaceValue
+    })
 
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
-        v match {
-          case v: jl.Float => fillCol[Float](f, v)
-          case v: jl.Double => fillCol[Double](f, v)
-          case v: jl.Long => fillCol[Long](f, v)
-          case v: jl.Integer => fillCol[Integer](f, v)
-          case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
-          case v: String => fillCol[String](f, v)
-        }
-      }.getOrElse(df.col(f.name))
+    val output = df.queryExecution.analyzed.output
+    val projections = output.map {
+      attr => attrToValue.get(attr).map {
+        case v: jl.Float => fillCol[Float](attr, v)
+        case v: jl.Double => fillCol[Double](attr, v)
+        case v: jl.Long => fillCol[Long](attr, v)
+        case v: jl.Integer => fillCol[Integer](attr, v)
+        case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue())
+        case v: String => fillCol[String](attr, v)
+      }.getOrElse(Column(attr))
     }
     df.select(projections : _*)
   }
 
   /**
-   * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
-   * It selects a column based on its name.
+   * Returns a [[Column]] expression that replaces null value in column defined by `attr`
+   * with `replacement`.
    */
-  private def fillCol[T](col: StructField, replacement: T): Column = {
-    val quotedColName = "`" + col.name + "`"
-    fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
+  private def fillCol[T](attr: Attribute, replacement: T): Column = {
+    fillCol(attr.dataType, attr.name, Column(attr), replacement)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 091877f7ca..23c2349f89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -460,4 +460,29 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
   }
+
+  test("SPARK-34417 - test fillMap() for column with a dot in the name") {
+    val na = "n/a"
+    checkAnswer(
+      Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")
+        .na.fill(Map("`ColWith.Dot`" -> na)),
+      Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
+  }
+
+  test("SPARK-34417 - test fillMap() for qualified-column with a dot in the name") {
+    val na = "n/a"
+    checkAnswer(
+      Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col").as("testDF")
+        .na.fill(Map("testDF.`ColWith.Dot`" -> na)),
+      Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
+  }
+
+  test("SPARK-34417 - test fillMap() for column without a dot in the name" +
+    " and dataframe with another column having a dot in the name") {
+    val na = "n/a"
+    checkAnswer(
+      Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("Col", "ColWith.Dot")
+        .na.fill(Map("Col" -> na)),
+      Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
+  }
 }
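Note on the user-facing behavior the new tests cover: fillMap now resolves each map key through df.resolve and keys the replacement values by attribute, so backtick-quoted column names containing dots resolve to the top-level attribute, while genuinely nested fields are rejected with UnsupportedOperationException. A minimal usage sketch (not part of the patch), assuming a running SparkSession named `spark` with its implicits imported:

  import spark.implicits._

  val df = Seq(("abc", 23L), (null, 0L)).toDF("ColWith.Dot", "Col")
  // Prior to this change, fill() failed to resolve the dotted column name;
  // with the change, the backtick-quoted key maps to the top-level attribute
  // and the null in "ColWith.Dot" is replaced with "n/a".
  df.na.fill(Map("`ColWith.Dot`" -> "n/a")).show()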