[SPARK-31256][SQL] DataFrameNaFunctions.drop should work for nested columns

### What changes were proposed in this pull request?

#26700 removed the ability to drop a row whose nested column value is null.

For example, for the following `df`:
```
val schema = new StructType()
  .add("c1", new StructType()
    .add("c1-1", StringType)
    .add("c1-2", StringType))
val data = Seq(Row(Row(null, "a2")), Row(Row("b1", "b2")), Row(null))
val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
df.show
+--------+
|      c1|
+--------+
|  [, a2]|
|[b1, b2]|
|    null|
+--------+
```
In Spark 2.4.4,
```
df.na.drop("any", Seq("c1.c1-1")).show
+--------+
|      c1|
+--------+
|[b1, b2]|
+--------+
```
In Spark 2.4.5 and Spark 3.0.0-preview2, nested columns are silently ignored if specified:
```
df.na.drop("any", Seq("c1.c1-1")).show
+--------+
|      c1|
+--------+
|  [, a2]|
|[b1, b2]|
|    null|
+--------+
```
### Why are the changes needed?

This seems like a regression.

### Does this PR introduce any user-facing change?

Yes. A nested column can now be specified:
```
df.na.drop("any", Seq("c1.c1-1")).show
+--------+
|      c1|
+--------+
|[b1, b2]|
+--------+
```

Also, if `*` is specified as a column, an `AnalysisException` is now thrown stating that `*` cannot be resolved, which restores the 2.4.4 behavior; in the current master it is silently ignored.
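A sketch of the restored behavior with the `df` above; the exception message tail is elided, since the updated test only asserts the quoted fragment:
```
df.na.drop("any", Seq("*"))
// org.apache.spark.sql.AnalysisException: Cannot resolve column name "*" ...
```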

### How was this patch tested?

Updated existing tests.

Closes #28266 from imback82/SPARK-31256.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

`DataFrameNaFunctions.scala`:

```diff
@@ -89,7 +89,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(how: String, cols: Seq[String]): DataFrame = {
-    drop0(how, toAttributes(cols))
+    drop0(how, cols.map(df.resolve(_)))
   }
 
   /**
@@ -115,7 +115,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
-    drop0(minNonNulls, toAttributes(cols))
+    drop0(minNonNulls, cols.map(df.resolve(_)))
   }
 
   /**
@@ -480,7 +480,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     df.queryExecution.analyzed.output
   }
 
-  private def drop0(how: String, cols: Seq[Attribute]): DataFrame = {
+  private def drop0(how: String, cols: Seq[NamedExpression]): DataFrame = {
     how.toLowerCase(Locale.ROOT) match {
       case "any" => drop0(cols.size, cols)
       case "all" => drop0(1, cols)
@@ -488,12 +488,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     }
   }
 
-  private def drop0(minNonNulls: Int, cols: Seq[Attribute]): DataFrame = {
+  private def drop0(minNonNulls: Int, cols: Seq[NamedExpression]): DataFrame = {
     // Filtering condition:
     // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
-    val predicate = AtLeastNNonNulls(
-      minNonNulls,
-      outputAttributes.filter{ col => cols.exists(_.semanticEquals(col)) })
+    val predicate = AtLeastNNonNulls(minNonNulls, cols)
     df.filter(Column(predicate))
   }
```

`DataFrameNaFunctionsSuite.scala`:

```diff
@@ -45,6 +45,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     ).toDF("int", "long", "short", "byte", "float", "double")
   }
 
+  def createDFWithNestedColumns: DataFrame = {
+    val schema = new StructType()
+      .add("c1", new StructType()
+        .add("c1-1", StringType)
+        .add("c1-2", StringType))
+    val data = Seq(Row(Row(null, "a2")), Row(Row("b1", "b2")), Row(null))
+    spark.createDataFrame(
+      spark.sparkContext.parallelize(data), schema)
+  }
+
   test("drop") {
     val input = createDF()
     val rows = input.collect()
@@ -275,33 +285,35 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(message.contains("Reference 'f2' is ambiguous"))
   }
 
-  test("fill/drop with col(*)") {
+  test("fill with col(*)") {
     val df = createDF()
     // If columns are specified with "*", they are ignored.
     checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
-    checkAnswer(df.na.drop("any", Seq("*")), df.collect())
   }
 
-  test("fill/drop with nested columns") {
-    val schema = new StructType()
-      .add("c1", new StructType()
-        .add("c1-1", StringType)
-        .add("c1-2", StringType))
+  test("drop with col(*)") {
+    val df = createDF()
+    val exception = intercept[AnalysisException] {
+      df.na.drop("any", Seq("*"))
+    }
+    assert(exception.getMessage.contains("Cannot resolve column name \"*\""))
+  }
 
-    val data = Seq(
-      Row(Row(null, "a2")),
-      Row(Row("b1", "b2")),
-      Row(null))
+  test("fill with nested columns") {
+    val df = createDFWithNestedColumns
 
-    val df = spark.createDataFrame(
-      spark.sparkContext.parallelize(data), schema)
+    // Nested columns are ignored for fill().
+    checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), df)
+  }
 
-    checkAnswer(df.select("c1.c1-1"),
-      Row(null) :: Row("b1") :: Row(null) :: Nil)
+  test("drop with nested columns") {
+    val df = createDFWithNestedColumns
 
-    // Nested columns are ignored for fill() and drop().
-    checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
-    checkAnswer(df.na.drop("any", Seq("c1.c1-1")), data)
+    // Rows with the specified nested columns whose null values are dropped.
+    assert(df.count == 3)
+    checkAnswer(
+      df.na.drop("any", Seq("c1.c1-1")),
+      Seq(Row(Row("b1", "b2"))))
   }
 
   test("replace") {
```