[SPARK-6562][SQL] DataFrame.replace
Supports replacing values with other values in DataFrames. Python support should be in a separate pull request. Author: Reynold Xin <rxin@databricks.com> Closes #5282 from rxin/df-na-replace and squashes the following commits: 4b72434 [Reynold Xin] Removed println. c8d9946 [Reynold Xin] col -> cols fbb3c21 [Reynold Xin] [SPARK-6562][SQL] DataFrame.replace
This commit is contained in:
parent
9294044985
commit
68d1faa3c0
|
@ -192,6 +192,127 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
|
|||
*/
|
||||
def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
|
||||
|
||||
/**
|
||||
* Replaces values matching keys in `replacement` map with the corresponding values.
|
||||
* Key and value of `replacement` map must have the same type, and can only be doubles or strings.
|
||||
* If `col` is "*", then the replacement is applied on all string columns or numeric columns.
|
||||
*
|
||||
* {{{
|
||||
* import com.google.common.collect.ImmutableMap;
|
||||
*
|
||||
* // Replaces all occurrences of 1.0 with 2.0 in column "height".
|
||||
* df.replace("height", ImmutableMap.of(1.0, 2.0));
|
||||
*
|
||||
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
|
||||
* df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
|
||||
*
|
||||
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
|
||||
* df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
|
||||
* }}}
|
||||
*
|
||||
* @param col name of the column to apply the value replacement
|
||||
* @param replacement value replacement map, as explained above
|
||||
*/
|
||||
def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
|
||||
replace[T](col, replacement.toMap : Map[T, T])
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces values matching keys in `replacement` map with the corresponding values.
|
||||
* Key and value of `replacement` map must have the same type, and can only be doubles or strings.
|
||||
*
|
||||
* {{{
|
||||
* import com.google.common.collect.ImmutableMap;
|
||||
*
|
||||
* // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
|
||||
* df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
|
||||
*
|
||||
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
|
||||
* df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
|
||||
* }}}
|
||||
*
|
||||
* @param cols list of columns to apply the value replacement
|
||||
* @param replacement value replacement map, as explained above
|
||||
*/
|
||||
def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = {
|
||||
replace(cols.toSeq, replacement.toMap)
|
||||
}
|
||||
|
||||
/**
|
||||
* (Scala-specific) Replaces values matching keys in `replacement` map.
|
||||
* Key and value of `replacement` map must have the same type, and can only be doubles or strings.
|
||||
* If `col` is "*", then the replacement is applied on all string columns or numeric columns.
|
||||
*
|
||||
* {{{
|
||||
* // Replaces all occurrences of 1.0 with 2.0 in column "height".
|
||||
* df.replace("height", Map(1.0 -> 2.0))
|
||||
*
|
||||
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
|
||||
* df.replace("name", Map("UNKNOWN" -> "unnamed")
|
||||
*
|
||||
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
|
||||
* df.replace("*", Map("UNKNOWN" -> "unnamed")
|
||||
* }}}
|
||||
*
|
||||
* @param col name of the column to apply the value replacement
|
||||
* @param replacement value replacement map, as explained above
|
||||
*/
|
||||
def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
|
||||
if (col == "*") {
|
||||
replace0(df.columns, replacement)
|
||||
} else {
|
||||
replace0(Seq(col), replacement)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* (Scala-specific) Replaces values matching keys in `replacement` map.
|
||||
* Key and value of `replacement` map must have the same type, and can only be doubles or strings.
|
||||
*
|
||||
* {{{
|
||||
* // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
|
||||
* df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
|
||||
*
|
||||
* // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
|
||||
* df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed");
|
||||
* }}}
|
||||
*
|
||||
* @param cols list of columns to apply the value replacement
|
||||
* @param replacement value replacement map, as explained above
|
||||
*/
|
||||
def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)
|
||||
|
||||
private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
|
||||
if (replacement.isEmpty || cols.isEmpty) {
|
||||
return df
|
||||
}
|
||||
|
||||
// replacementMap is either Map[String, String] or Map[Double, Double]
|
||||
val replacementMap: Map[_, _] = replacement.head._2 match {
|
||||
case v: String => replacement
|
||||
case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) }
|
||||
}
|
||||
|
||||
// targetColumnType is either DoubleType or StringType
|
||||
val targetColumnType = replacement.head._1 match {
|
||||
case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
|
||||
case _: String => StringType
|
||||
}
|
||||
|
||||
val columnEquals = df.sqlContext.analyzer.resolver
|
||||
val projections = df.schema.fields.map { f =>
|
||||
val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
|
||||
if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
|
||||
replaceCol(f, replacementMap)
|
||||
} else if (f.dataType == targetColumnType && shouldReplace) {
|
||||
replaceCol(f, replacementMap)
|
||||
} else {
|
||||
df.col(f.name)
|
||||
}
|
||||
}
|
||||
df.select(projections : _*)
|
||||
}
|
||||
|
||||
private def fill0(values: Seq[(String, Any)]): DataFrame = {
|
||||
// Error handling
|
||||
values.foreach { case (colName, replaceValue) =>
|
||||
|
@ -228,4 +349,27 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
|
|||
private def fillCol[T](col: StructField, replacement: T): Column = {
|
||||
coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a [[Column]] expression that replaces value matching key in `replacementMap` with
|
||||
* value in `replacementMap`, using [[CaseWhen]].
|
||||
*
|
||||
* TODO: This can be optimized to use broadcast join when replacementMap is large.
|
||||
*/
|
||||
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
|
||||
val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
|
||||
df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
|
||||
lit(target).cast(col.dataType).expr :: Nil
|
||||
}.toSeq
|
||||
new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
|
||||
}
|
||||
|
||||
private def convertToDouble(v: Any): Double = v match {
|
||||
case v: Float => v.toDouble
|
||||
case v: Double => v
|
||||
case v: Long => v.toDouble
|
||||
case v: Int => v.toDouble
|
||||
case v => throw new IllegalArgumentException(
|
||||
s"Unsupported value type ${v.getClass.getName} ($v).")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -154,4 +154,38 @@ class DataFrameNaFunctionsSuite extends QueryTest {
|
|||
))),
|
||||
Row("test", null, 1, 2.2))
|
||||
}
|
||||
|
||||
test("replace") {
|
||||
val input = createDF()
|
||||
|
||||
// Replace two numeric columns: age and height
|
||||
val out = input.na.replace(Seq("age", "height"), Map(
|
||||
16 -> 61,
|
||||
60 -> 6,
|
||||
164.3 -> 461.3 // Alice is really tall
|
||||
))
|
||||
|
||||
checkAnswer(
|
||||
out,
|
||||
Row("Bob", 61, 176.5) ::
|
||||
Row("Alice", null, 461.3) ::
|
||||
Row("David", 6, null) ::
|
||||
Row("Amy", null, null) ::
|
||||
Row(null, null, null) :: Nil)
|
||||
|
||||
// Replace only the age column
|
||||
val out1 = input.na.replace("age", Map(
|
||||
16 -> 61,
|
||||
60 -> 6,
|
||||
164.3 -> 461.3 // Alice is really tall
|
||||
))
|
||||
|
||||
checkAnswer(
|
||||
out1,
|
||||
Row("Bob", 61, 176.5) ::
|
||||
Row("Alice", null, 164.3) ::
|
||||
Row("David", 6, null) ::
|
||||
Row("Amy", null, null) ::
|
||||
Row(null, null, null) :: Nil)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue