[SPARK-34649][SQL][DOCS] org.apache.spark.sql.DataFrameNaFunctions.replace() fails for column name having a dot
### What changes were proposed in this pull request? Use resolved attributes instead of data-frame fields for replacing values. ### Why are the changes needed? dataframe.na.replace() does not work for column having a dot in the name ### Does this PR introduce _any_ user-facing change? None ### How was this patch tested? Added unit tests for the same Closes #31769 from amandeep-sharma/master. Authored-by: Amandeep Sharma <happyaman91@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
b5b198516c
commit
a9c11896a5
|
@ -66,6 +66,8 @@ license: |
|
|||
- In Spark 3.2, the output schema of `SHOW TBLPROPERTIES` becomes `key: string, value: string` whether you specify the table property key or not. In Spark 3.1 and earlier, the output schema of `SHOW TBLPROPERTIES` is `value: string` when you specify the table property key. To restore the old schema with the builtin catalog, you can set `spark.sql.legacy.keepCommandOutputSchema` to `true`.
|
||||
|
||||
- In Spark 3.2, we support typed literals in the partition spec of INSERT and ADD/DROP/RENAME PARTITION. For example, `ADD PARTITION(dt = date'2020-01-01')` adds a partition with date value `2020-01-01`. In Spark 3.1 and earlier, the partition value will be parsed as string value `date '2020-01-01'`, which is an illegal date value, and we add a partition with null value at the end.
|
||||
|
||||
- In Spark 3.2, `DataFrameNaFunctions.replace()` no longer uses exact string match for the input column names, to match the SQL syntax and support qualified column names. Input column name having a dot in the name (not nested) needs to be escaped with backtick \`. Now, it throws `AnalysisException` if the column is not found in the data frame schema. It also throws `IllegalArgumentException` if the input column name is a nested column. In Spark 3.1 and earlier, it used to ignore invalid input column name and nested column name.
|
||||
|
||||
## Upgrading from Spark SQL 3.0 to 3.1
|
||||
|
||||
|
|
|
@ -327,9 +327,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
|
|||
*/
|
||||
def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
|
||||
if (col == "*") {
|
||||
replace0(df.columns, replacement)
|
||||
replace0(df.logicalPlan.output, replacement)
|
||||
} else {
|
||||
replace0(Seq(col), replacement)
|
||||
replace(Seq(col), replacement)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -352,10 +352,21 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
|
|||
*
|
||||
* @since 1.3.1
|
||||
*/
|
||||
def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)
|
||||
def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
|
||||
val attrs = cols.map { colName =>
|
||||
// Check column name exists
|
||||
val attr = df.resolve(colName) match {
|
||||
case a: Attribute => a
|
||||
case _ => throw new UnsupportedOperationException(
|
||||
s"Nested field ${colName} is not supported.")
|
||||
}
|
||||
attr
|
||||
}
|
||||
replace0(attrs, replacement)
|
||||
}
|
||||
|
||||
private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
|
||||
if (replacement.isEmpty || cols.isEmpty) {
|
||||
private def replace0[T](attrs: Seq[Attribute], replacement: Map[T, T]): DataFrame = {
|
||||
if (replacement.isEmpty || attrs.isEmpty) {
|
||||
return df
|
||||
}
|
||||
|
||||
|
@ -379,15 +390,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
|
|||
case _: String => StringType
|
||||
}
|
||||
|
||||
val columnEquals = df.sparkSession.sessionState.analyzer.resolver
|
||||
val projections = df.schema.fields.map { f =>
|
||||
val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
|
||||
if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
|
||||
replaceCol(f, replacementMap)
|
||||
} else if (f.dataType == targetColumnType && shouldReplace) {
|
||||
replaceCol(f, replacementMap)
|
||||
val output = df.queryExecution.analyzed.output
|
||||
val projections = output.map { attr =>
|
||||
if (attrs.contains(attr) && (attr.dataType == targetColumnType ||
|
||||
(attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) {
|
||||
replaceCol(attr, replacementMap)
|
||||
} else {
|
||||
df.col(f.name)
|
||||
Column(attr)
|
||||
}
|
||||
}
|
||||
df.select(projections : _*)
|
||||
|
@ -453,13 +462,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
|
|||
*
|
||||
* TODO: This can be optimized to use broadcast join when replacementMap is large.
|
||||
*/
|
||||
private def replaceCol[K, V](col: StructField, replacementMap: Map[K, V]): Column = {
|
||||
val keyExpr = df.col(col.name).expr
|
||||
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
|
||||
private def replaceCol[K, V](attr: Attribute, replacementMap: Map[K, V]): Column = {
|
||||
def buildExpr(v: Any) = Cast(Literal(v), attr.dataType)
|
||||
val branches = replacementMap.flatMap { case (source, target) =>
|
||||
Seq(Literal(source), buildExpr(target))
|
||||
}.toSeq
|
||||
new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
|
||||
new Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
|
||||
}
|
||||
|
||||
private def convertToDouble(v: Any): Double = v match {
|
||||
|
|
|
@ -461,7 +461,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34417 - test fillMap() for column with a dot in the name") {
|
||||
test("SPARK-34417: test fillMap() for column with a dot in the name") {
|
||||
val na = "n/a"
|
||||
checkAnswer(
|
||||
Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")
|
||||
|
@ -469,7 +469,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34417 - test fillMap() for qualified-column with a dot in the name") {
|
||||
test("SPARK-34417: test fillMap() for qualified-column with a dot in the name") {
|
||||
val na = "n/a"
|
||||
checkAnswer(
|
||||
Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col").as("testDF")
|
||||
|
@ -477,7 +477,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34417 - test fillMap() for column without a dot in the name" +
|
||||
test("SPARK-34417: test fillMap() for column without a dot in the name" +
|
||||
" and dataframe with another column having a dot in the name") {
|
||||
val na = "n/a"
|
||||
checkAnswer(
|
||||
|
@ -485,4 +485,41 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
.na.fill(Map("Col" -> na)),
|
||||
Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34649: replace value of a column with dot in the name") {
|
||||
checkAnswer(
|
||||
Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
|
||||
.na.replace("`Col.1`", Map( "n/a" -> "unknown")),
|
||||
Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34649: replace value of a qualified-column with dot in the name") {
|
||||
checkAnswer(
|
||||
Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2").as("testDf")
|
||||
.na.replace("testDf.`Col.1`", Map( "n/a" -> "unknown")),
|
||||
Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34649: replace value of a dataframe having dot in the all column names") {
|
||||
checkAnswer(
|
||||
Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
|
||||
.na.replace("*", Map( "n/a" -> "unknown")),
|
||||
Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34649: replace value of a column not present in the dataframe") {
|
||||
val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
|
||||
val exception = intercept[AnalysisException] {
|
||||
df.na.replace("aa", Map( "n/a" -> "unknown"))
|
||||
}
|
||||
assert(exception.getMessage.equals("Cannot resolve column name \"aa\" among (Col.1, Col.2)"))
|
||||
}
|
||||
|
||||
test("SPARK-34649: replace value of a nested column") {
|
||||
val df = createDFWithNestedColumns
|
||||
val exception = intercept[UnsupportedOperationException] {
|
||||
df.na.replace("c1.c1-1", Map("b1" ->"a1"))
|
||||
}
|
||||
assert(exception.getMessage.equals("Nested field c1.c1-1 is not supported."))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue