[SPARK-34649][SQL][DOCS] org.apache.spark.sql.DataFrameNaFunctions.replace() fails for column name having a dot

### What changes were proposed in this pull request?

Use resolved attributes instead of DataFrame schema fields when replacing values, so column names are matched by resolution (supporting dots and qualifiers) rather than by exact string comparison.

### Why are the changes needed?

`dataframe.na.replace()` does not work for a column with a dot in its name.
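A minimal repro sketch (assumes a `spark-shell` session with `spark.implicits._` in scope; the data and column names are illustrative):

```scala
import spark.implicits._

val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")

// Before this change, replace() could not handle the dotted column name;
// with the fix, the backtick-escaped name resolves to the top-level column
// and "n/a" is replaced by "unknown":
df.na.replace("`Col.1`", Map("n/a" -> "unknown")).show()
```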

### Does this PR introduce _any_ user-facing change?

Yes. `replace()` no longer silently ignores an unresolvable column name: it now throws `AnalysisException` when the column is not found and `UnsupportedOperationException` when the name refers to a nested column (documented in the migration guide below).

### How was this patch tested?

Added unit tests covering the new behavior.

Closes #31769 from amandeep-sharma/master.

Authored-by: Amandeep Sharma <happyaman91@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit a9c11896a5 (parent b5b198516c)
Amandeep Sharma, 2021-03-09 11:47:01 +00:00, committed by Wenchen Fan
3 changed files with 67 additions and 20 deletions

docs/sql-migration-guide.md

@@ -66,6 +66,8 @@ license: |
 - In Spark 3.2, the output schema of `SHOW TBLPROPERTIES` becomes `key: string, value: string` whether you specify the table property key or not. In Spark 3.1 and earlier, the output schema of `SHOW TBLPROPERTIES` is `value: string` when you specify the table property key. To restore the old schema with the builtin catalog, you can set `spark.sql.legacy.keepCommandOutputSchema` to `true`.
 - In Spark 3.2, we support typed literals in the partition spec of INSERT and ADD/DROP/RENAME PARTITION. For example, `ADD PARTITION(dt = date'2020-01-01')` adds a partition with date value `2020-01-01`. In Spark 3.1 and earlier, the partition value will be parsed as string value `date '2020-01-01'`, which is an illegal date value, and we add a partition with null value at the end.
+- In Spark 3.2, `DataFrameNaFunctions.replace()` no longer uses exact string matching for input column names, to match the SQL syntax and to support qualified column names. An input column name containing a dot (for a top-level, non-nested column) needs to be escaped with backticks \`. `replace()` now throws `AnalysisException` if the column is not found in the DataFrame schema, and `UnsupportedOperationException` if the input column name refers to a nested column. In Spark 3.1 and earlier, invalid input column names and nested column names were silently ignored.
+
 ## Upgrading from Spark SQL 3.0 to 3.1
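For illustration, a sketch of the behaviors listed in the new entry (all names are examples; the nested case assumes a frame with a struct column `c1`):

```scala
import spark.implicits._

val df = Seq(("n/a", 0)).toDF("Col.1", "Col.2")

// Top-level name with a dot: escape it with backticks.
df.na.replace("`Col.1`", Map("n/a" -> "unknown"))  // replaces the value

// Unknown column: no longer silently ignored.
// df.na.replace("missing", Map("n/a" -> "x"))     // AnalysisException

// Nested field (on a frame with a struct column c1): rejected.
// df.na.replace("c1.c1-1", Map("b1" -> "a1"))     // UnsupportedOperationException
```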

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

@@ -327,9 +327,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    */
   def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
     if (col == "*") {
-      replace0(df.columns, replacement)
+      replace0(df.logicalPlan.output, replacement)
     } else {
-      replace0(Seq(col), replacement)
+      replace(Seq(col), replacement)
     }
   }
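The `"*"` case now passes the already-resolved output attributes straight through, so dotted names need no special handling there; the single-column case is routed through the resolving `replace(Seq(...))` overload below. A usage sketch (illustrative data):

```scala
// Applies the replacement to every column of a matching type,
// including columns whose names contain dots:
df.na.replace("*", Map("n/a" -> "unknown"))
```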
@@ -352,10 +352,21 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)
+  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
+    val attrs = cols.map { colName =>
+      // Check column name exists
+      val attr = df.resolve(colName) match {
+        case a: Attribute => a
+        case _ => throw new UnsupportedOperationException(
+          s"Nested field ${colName} is not supported.")
+      }
+      attr
+    }
+    replace0(attrs, replacement)
+  }

-  private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
-    if (replacement.isEmpty || cols.isEmpty) {
+  private def replace0[T](attrs: Seq[Attribute], replacement: Map[T, T]): DataFrame = {
+    if (replacement.isEmpty || attrs.isEmpty) {
       return df
     }
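`df.resolve` (which is `private[sql]`) returns a plain `Attribute` only for a top-level column; a nested field resolves to a struct-field extraction wrapped in an `Alias`, which is why the match above rejects it. A sketch of the same distinction through the public `Dataset.col` path (names are examples):

```scala
val df = spark.sql("SELECT named_struct('c1-1', 'b1') AS c1, 'n/a' AS c2")

df.col("c2").expr       // AttributeReference: a top-level column
df.col("c1.c1-1").expr  // Alias over a struct-field extraction, not an Attribute
```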
@@ -379,15 +390,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       case _: String => StringType
     }

-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
-      if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
-        replaceCol(f, replacementMap)
-      } else if (f.dataType == targetColumnType && shouldReplace) {
-        replaceCol(f, replacementMap)
+    val output = df.queryExecution.analyzed.output
+    val projections = output.map { attr =>
+      if (attrs.contains(attr) && (attr.dataType == targetColumnType ||
+        (attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) {
+        replaceCol(attr, replacementMap)
       } else {
-        df.col(f.name)
+        Column(attr)
       }
     }
     df.select(projections : _*)
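The numeric branch above lets an integer- or double-keyed replacement map apply to any numeric column, because the map's key type is normalized to `DoubleType`. A sketch (hypothetical frame):

```scala
val people = Seq(("alice", 10), ("bob", -1)).toDF("name", "age")

// "age" is IntegerType, but it matches because NumericType columns
// are accepted whenever the replacement map is keyed by numbers:
people.na.replace("age", Map(-1 -> 0))
```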
@@ -453,13 +462,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * TODO: This can be optimized to use broadcast join when replacementMap is large.
    */
-  private def replaceCol[K, V](col: StructField, replacementMap: Map[K, V]): Column = {
-    val keyExpr = df.col(col.name).expr
-    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+  private def replaceCol[K, V](attr: Attribute, replacementMap: Map[K, V]): Column = {
+    def buildExpr(v: Any) = Cast(Literal(v), attr.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(Literal(source), buildExpr(target))
     }.toSeq
-    new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
+    new Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
   }

   private def convertToDouble(v: Any): Double = v match {
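`CaseKeyWhen(attr, branches :+ attr)` evaluates as `CASE <col> WHEN <source> THEN <target> ... ELSE <col> END`. A rough public-API equivalent for a single mapping (a sketch under that reading, not the literal implementation):

```scala
import org.apache.spark.sql.functions.{col, when}

// What replaceCol builds for Map("n/a" -> "unknown") on `Col.1`, in spirit:
val replaced = when(col("`Col.1`") === "n/a", "unknown")
  .otherwise(col("`Col.1`"))
  .as("Col.1")
```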

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

@@ -461,7 +461,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
   }

-  test("SPARK-34417 - test fillMap() for column with a dot in the name") {
+  test("SPARK-34417: test fillMap() for column with a dot in the name") {
     val na = "n/a"
     checkAnswer(
       Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")
@@ -469,7 +469,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }

-  test("SPARK-34417 - test fillMap() for qualified-column with a dot in the name") {
+  test("SPARK-34417: test fillMap() for qualified-column with a dot in the name") {
     val na = "n/a"
     checkAnswer(
       Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col").as("testDF")
@@ -477,7 +477,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }

-  test("SPARK-34417 - test fillMap() for column without a dot in the name" +
+  test("SPARK-34417: test fillMap() for column without a dot in the name" +
     " and dataframe with another column having a dot in the name") {
     val na = "n/a"
     checkAnswer(
@@ -485,4 +485,41 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       .na.fill(Map("Col" -> na)),
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }
+
+  test("SPARK-34649: replace value of a column with dot in the name") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+        .na.replace("`Col.1`", Map("n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a qualified-column with dot in the name") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2").as("testDf")
+        .na.replace("testDf.`Col.1`", Map("n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a dataframe having dot in all the column names") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+        .na.replace("*", Map("n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a column not present in the dataframe") {
+    val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+    val exception = intercept[AnalysisException] {
+      df.na.replace("aa", Map("n/a" -> "unknown"))
+    }
+    assert(exception.getMessage.equals("Cannot resolve column name \"aa\" among (Col.1, Col.2)"))
+  }
+
+  test("SPARK-34649: replace value of a nested column") {
+    val df = createDFWithNestedColumns
+    val exception = intercept[UnsupportedOperationException] {
+      df.na.replace("c1.c1-1", Map("b1" -> "a1"))
+    }
+    assert(exception.getMessage.equals("Nested field c1.c1-1 is not supported."))
+  }
 }