[SPARK-31256][SQL] DataFrameNaFunctions.drop should work for nested columns
### What changes were proposed in this pull request?

#26700 removed the ability to drop a row whose nested column value is null.

For example, for the following `df`:

```
val schema = new StructType()
  .add("c1", new StructType()
    .add("c1-1", StringType)
    .add("c1-2", StringType))
val data = Seq(Row(Row(null, "a2")), Row(Row("b1", "b2")), Row(null))
val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)

df.show
+--------+
|      c1|
+--------+
|  [, a2]|
|[b1, b2]|
|    null|
+--------+
```

In Spark 2.4.4:

```
df.na.drop("any", Seq("c1.c1-1")).show
+--------+
|      c1|
+--------+
|[b1, b2]|
+--------+
```

In Spark 2.4.5 and Spark 3.0.0-preview2, nested columns are silently ignored if specified:

```
df.na.drop("any", Seq("c1.c1-1")).show
+--------+
|      c1|
+--------+
|  [, a2]|
|[b1, b2]|
|    null|
+--------+
```

### Why are the changes needed?

This is a regression from Spark 2.4.4.

### Does this PR introduce any user-facing change?

Yes. A nested column can now be specified again:

```
df.na.drop("any", Seq("c1.c1-1")).show
+--------+
|      c1|
+--------+
|[b1, b2]|
+--------+
```

Also, if `*` is specified as a column, an `AnalysisException` is now thrown because `*` cannot be resolved, matching the 2.4.4 behavior; in the current master (without this patch), it silently has no effect.

### How was this patch tested?

Updated existing tests.

Closes #28266 from imback82/SPARK-31256.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
commit d7499aed9c (parent bc212df610)
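For reviewers who want to try this end to end, here is a minimal spark-shell sketch of both user-facing changes. It only restates the snippets above in runnable form (assuming the `spark` session provided by spark-shell); the `try`/`catch` around the `*` case is illustrative, not part of the patch:

```scala
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.types.{StringType, StructType}

val schema = new StructType()
  .add("c1", new StructType()
    .add("c1-1", StringType)
    .add("c1-2", StringType))
val data = Seq(Row(Row(null, "a2")), Row(Row("b1", "b2")), Row(null))
val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)

// With this patch: keeps only rows whose nested field `c1.c1-1` is non-null
// (rows where `c1` itself is null fail the check as well).
df.na.drop("any", Seq("c1.c1-1")).show()

// With this patch: "*" fails analysis instead of being silently ignored.
try df.na.drop("any", Seq("*"))
catch { case e: AnalysisException => println(e.getMessage) }
```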
`DataFrameNaFunctions.scala`:

```diff
@@ -89,7 +89,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(how: String, cols: Seq[String]): DataFrame = {
-    drop0(how, toAttributes(cols))
+    drop0(how, cols.map(df.resolve(_)))
   }
 
   /**
@@ -115,7 +115,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
-    drop0(minNonNulls, toAttributes(cols))
+    drop0(minNonNulls, cols.map(df.resolve(_)))
   }
 
   /**
@@ -480,7 +480,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     df.queryExecution.analyzed.output
   }
 
-  private def drop0(how: String, cols: Seq[Attribute]): DataFrame = {
+  private def drop0(how: String, cols: Seq[NamedExpression]): DataFrame = {
     how.toLowerCase(Locale.ROOT) match {
       case "any" => drop0(cols.size, cols)
       case "all" => drop0(1, cols)
@@ -488,12 +488,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     }
   }
 
-  private def drop0(minNonNulls: Int, cols: Seq[Attribute]): DataFrame = {
+  private def drop0(minNonNulls: Int, cols: Seq[NamedExpression]): DataFrame = {
     // Filtering condition:
     // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
-    val predicate = AtLeastNNonNulls(
-      minNonNulls,
-      outputAttributes.filter{ col => cols.exists(_.semanticEquals(col)) })
+    val predicate = AtLeastNNonNulls(minNonNulls, cols)
     df.filter(Column(predicate))
   }
```
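Why this works: as the deleted lines show, the old `drop0` built its predicate from the DataFrame's top-level output attributes only, so a nested name such as `c1.c1-1` matched no attribute and silently dropped out of the predicate. `df.resolve` instead goes through column resolution against the analyzed plan, so a nested name yields a `NamedExpression` (an extractor over the struct) that `AtLeastNNonNulls` can evaluate directly. For a single string-typed nested column, where the non-NaN half of the check is vacuous, the predicate the new code builds is equivalent to this public-API sketch (reusing `df` from the sketch above):

```scala
import org.apache.spark.sql.functions.col

// Keep rows where at least cols.size (= 1 here) of the listed columns are
// non-null; for a string field this reduces to a null check on the nested path.
df.filter(col("c1.c1-1").isNotNull).show()
// +--------+
// |      c1|
// +--------+
// |[b1, b2]|
// +--------+
```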
`DataFrameNaFunctionsSuite.scala`:

```diff
@@ -45,6 +45,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     ).toDF("int", "long", "short", "byte", "float", "double")
   }
 
+  def createDFWithNestedColumns: DataFrame = {
+    val schema = new StructType()
+      .add("c1", new StructType()
+        .add("c1-1", StringType)
+        .add("c1-2", StringType))
+    val data = Seq(Row(Row(null, "a2")), Row(Row("b1", "b2")), Row(null))
+    spark.createDataFrame(
+      spark.sparkContext.parallelize(data), schema)
+  }
+
   test("drop") {
     val input = createDF()
     val rows = input.collect()
@@ -275,33 +285,35 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(message.contains("Reference 'f2' is ambiguous"))
   }
 
-  test("fill/drop with col(*)") {
+  test("fill with col(*)") {
     val df = createDF()
     // If columns are specified with "*", they are ignored.
     checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
-    checkAnswer(df.na.drop("any", Seq("*")), df.collect())
   }
 
-  test("fill/drop with nested columns") {
-    val schema = new StructType()
-      .add("c1", new StructType()
-        .add("c1-1", StringType)
-        .add("c1-2", StringType))
-
-    val data = Seq(
-      Row(Row(null, "a2")),
-      Row(Row("b1", "b2")),
-      Row(null))
-
-    val df = spark.createDataFrame(
-      spark.sparkContext.parallelize(data), schema)
-
-    checkAnswer(df.select("c1.c1-1"),
-      Row(null) :: Row("b1") :: Row(null) :: Nil)
-
-    // Nested columns are ignored for fill() and drop().
-    checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
-    checkAnswer(df.na.drop("any", Seq("c1.c1-1")), data)
+  test("drop with col(*)") {
+    val df = createDF()
+    val exception = intercept[AnalysisException] {
+      df.na.drop("any", Seq("*"))
+    }
+    assert(exception.getMessage.contains("Cannot resolve column name \"*\""))
+  }
+
+  test("fill with nested columns") {
+    val df = createDFWithNestedColumns
+
+    // Nested columns are ignored for fill().
+    checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), df)
+  }
+
+  test("drop with nested columns") {
+    val df = createDFWithNestedColumns
+
+    // Rows with the specified nested columns whose null values are dropped.
+    assert(df.count == 3)
+    checkAnswer(
+      df.na.drop("any", Seq("c1.c1-1")),
+      Seq(Row(Row("b1", "b2"))))
   }
 
   test("replace") {
```
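The new tests pin down the asymmetry this patch leaves in place: `fill` still ignores nested columns while `drop` now honors them. Continuing the spark-shell session from the first sketch above:

```scala
// fill() is unchanged by this patch and still ignores nested columns, so the
// frame comes back as-is — exactly what the new "fill with nested columns"
// test asserts with checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), df).
df.na.fill("a1", Seq("c1.c1-1")).show()
// +--------+
// |      c1|
// +--------+
// |  [, a2]|
// |[b1, b2]|
// |    null|
// +--------+
```

The suite itself can be run with the usual Spark sbt workflow, e.g. `build/sbt "sql/testOnly *DataFrameNaFunctionsSuite"` (command assumed from standard Spark developer practice, not stated in this commit).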