[SPARK-32511][SQL] Add dropFields method to Column class

### What changes were proposed in this pull request?

Added a new `dropFields` method to the `Column` class.
This method allows users to drop a `StructField` from a `StructType` column (with semantics similar to the `drop` method on `Dataset`).
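For instance, a minimal sketch of the two levels of API (the `df` and field names here are illustrative, not from this PR):
```
// Existing Dataset-level API: drops an entire top-level column.
df.drop("struct_col")

// New Column-level API (this PR): drops a single field inside a struct column.
df.withColumn("struct_col", $"struct_col".dropFields("b"))
```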

### Why are the changes needed?

Often Spark users have to work with deeply nested data, e.g. to fix a data quality issue in an existing `StructField`. To do this with the existing Spark APIs, users have to rebuild the entire struct column.

For example, let's say you have the following deeply nested data structure, which has a data quality issue (the value `5` is missing):
```
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val data = spark.createDataFrame(sc.parallelize(
      Seq(Row(Row(Row(1, 2, 3), Row(Row(4, null, 6), Row(7, 8, 9), Row(10, 11, 12)), Row(13, 14, 15)))),
      StructType(Seq(
        StructField("a", StructType(Seq(
          StructField("a", StructType(Seq(
            StructField("a", IntegerType),
            StructField("b", IntegerType),
            StructField("c", IntegerType)))),
          StructField("b", StructType(Seq(
            StructField("a", StructType(Seq(
              StructField("a", IntegerType),
              StructField("b", IntegerType),
              StructField("c", IntegerType)))),
            StructField("b", StructType(Seq(
              StructField("a", IntegerType),
              StructField("b", IntegerType),
              StructField("c", IntegerType)))),
            StructField("c", StructType(Seq(
              StructField("a", IntegerType),
              StructField("b", IntegerType),
              StructField("c", IntegerType))))
          ))),
          StructField("c", StructType(Seq(
            StructField("a", IntegerType),
            StructField("b", IntegerType),
            StructField("c", IntegerType))))
        )))))).cache

data.show(false)
+-------------------------------------------------------------+
|a                                                            |
+-------------------------------------------------------------+
|[[1, 2, 3], [[4,, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+-------------------------------------------------------------+
```
Currently, to drop the field containing the missing value (`a.b.a.b`), users would have to do something like this:
```
val result = data.withColumn("a",
  struct(
    $"a.a",
    struct(
      struct(
        $"a.b.a.a",
        $"a.b.a.c"
      ).as("a"),
      $"a.b.b",
      $"a.b.c"
    ).as("b"),
    $"a.c"
  ))

result.show(false)
+------------------------------------------------------------+
|a                                                           |
+------------------------------------------------------------+
|[[1, 2, 3], [[4, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+------------------------------------------------------------+
```
As you can see above, with the existing APIs users must call the `struct` function and list every field, including the fields they don't want to change. This is not ideal; as [SPARK-16483](https://issues.apache.org/jira/browse/SPARK-16483) puts it:
> this leads to complex, fragile code that cannot survive schema evolution.

In contrast, with the method added in this PR, a user could simply do something like this to get the same result:
```
val result = data.withColumn("a", 'a.dropFields("b.a.b"))
result.show(false)
+------------------------------------------------------------+
|a                                                           |
+------------------------------------------------------------+
|[[1, 2, 3], [[4, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+------------------------------------------------------------+
```

This is the second of possibly three methods that could be added to the `Column` class to make it easier to manipulate nested data.
Other methods under discussion in [SPARK-22231](https://issues.apache.org/jira/browse/SPARK-22231) include `withFieldRenamed`.
However, that method should be added in a separate PR.
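Until then, a rename can be approximated by composing the two methods that do exist, at the cost of field ordering (the renamed field is appended to the end of the struct rather than kept in place). A sketch with illustrative column and field names:
```
// Copy the field's value under the new name, then drop the old field.
df.withColumn("struct_col",
  $"struct_col"
    .withField("new_name", $"struct_col.old_name")
    .dropFields("old_name"))
```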

### Does this PR introduce _any_ user-facing change?

There is only one minor change. If the user submits the following query:
```
df.withColumn("a", $"a".withField(null, null))
```
instead of throwing:
```
java.lang.IllegalArgumentException: requirement failed: fieldName cannot be null
```
it will now throw:
```
java.lang.IllegalArgumentException: requirement failed: col cannot be null
```
I don't believe it should be an issue to change this because:
- neither message is incorrect (the `fieldName` null check simply moved into the shared `nameParts` helper, which runs after the `col` null check, so the `col` message now surfaces first)
- Spark 3.1.0 has yet to be released

but please feel free to correct me if I am wrong.

### How was this patch tested?

New unit tests were added; they must pass in Jenkins.

### Related JIRAs:
More discussion on this topic can be found here:
- https://issues.apache.org/jira/browse/SPARK-22231
- https://issues.apache.org/jira/browse/SPARK-16483

Closes #29322 from fqaiser94/SPARK-32511.

Lead-authored-by: fqaiser94@gmail.com <fqaiser94@gmail.com>
Co-authored-by: fqaiser94 <fqaiser94@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

8 changed files with 579 additions and 123 deletions

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedException}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -541,57 +541,97 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 }
 
 /**
- * Adds/replaces field in struct by name.
- */
-case class WithFields(
-    structExpr: Expression,
-    names: Seq[String],
-    valExprs: Seq[Expression]) extends Unevaluable {
-
-  assert(names.length == valExprs.length)
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (!structExpr.dataType.isInstanceOf[StructType]) {
-      TypeCheckResult.TypeCheckFailure(
-        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
-    } else {
-      TypeCheckResult.TypeCheckSuccess
-    }
-  }
-
-  override def children: Seq[Expression] = structExpr +: valExprs
-
-  override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
-
-  override def nullable: Boolean = structExpr.nullable
-
-  override def prettyName: String = "with_fields"
-
-  lazy val evalExpr: Expression = {
-    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
-      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
-    }
-
-    val addOrReplaceExprs = names.zip(valExprs)
-
-    val resolver = SQLConf.get.resolver
-    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
-      case (resultExprs, newExpr @ (newExprName, _)) =>
-        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
-          resultExprs.map {
-            case (name, _) if resolver(name, newExprName) => newExpr
-            case x => x
-          }
-        } else {
-          resultExprs :+ newExpr
-        }
-    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
-
-    val expr = CreateNamedStruct(newExprs)
-    if (structExpr.nullable) {
-      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
-    } else {
-      expr
-    }
-  }
-}
+ * Represents an operation to be applied to the fields of a struct.
+ */
+trait StructFieldsOperation {
+
+  val resolver: Resolver = SQLConf.get.resolver
+
+  /**
+   * Returns an updated list of expressions which will ultimately be used as the children argument
+   * for [[CreateNamedStruct]].
+   */
+  def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ *
+ * We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
+ * children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
+ */
+case class WithField(name: String, valExpr: Expression)
+  extends Unevaluable with StructFieldsOperation {
+
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    if (exprs.exists(x => resolver(x._1, name))) {
+      exprs.map {
+        case (existingName, _) if resolver(existingName, name) => (name, valExpr)
+        case x => x
+      }
+    } else {
+      exprs :+ (name, valExpr)
+    }
+
+  override def children: Seq[Expression] = valExpr :: Nil
+
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+
+  override def prettyName: String = "WithField"
+}
+
+/**
+ * Drop a field by name.
+ */
+case class DropField(name: String) extends StructFieldsOperation {
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    exprs.filterNot(expr => resolver(expr._1, name))
+}
+
+/**
+ * Updates fields in struct by name.
+ */
+case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation])
+  extends Unevaluable {
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val dataType = structExpr.dataType
+    if (!dataType.isInstanceOf[StructType]) {
+      TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
+        dataType.catalogString)
+    } else if (newExprs.isEmpty) {
+      TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def children: Seq[Expression] = structExpr +: fieldOps.collect {
+    case e: Expression => e
+  }
+
+  override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
+
+  override def nullable: Boolean = structExpr.nullable
+
+  override def prettyName: String = "update_fields"
+
+  private lazy val existingExprs: Seq[(String, Expression)] =
+    structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i))
+    }
+
+  private lazy val newExprs = fieldOps.foldLeft(existingExprs)((exprs, op) => op(exprs))
+
+  private lazy val createNamedStructExpr = CreateNamedStruct(newExprs.flatMap {
+    case (name, expr) => Seq(Literal(name), expr)
+  })
+
+  lazy val evalExpr: Expression = if (structExpr.nullable) {
+    If(IsNull(structExpr), Literal(null, createNamedStructExpr.dataType), createNamedStructExpr)
+  } else {
+    createNamedStructExpr
+  }
+}

@@ -39,17 +39,17 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
 
-      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
-        val name = w.dataType(ordinal).name
-        val matches = names.zip(valExprs).filter(_._1 == name)
+      case GetStructField(u: UpdateFields, ordinal, maybeName) =>
+        val name = u.dataType(ordinal).name
+        val matches = u.fieldOps.collect { case w: WithField if w.name == name => w }
         if (matches.nonEmpty) {
           // return last matching element as that is the final value for the field being extracted.
           // For example, if a user submits a query like this:
           // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
           // we want to return `lit(2)` (and not `lit(1)`).
-          matches.last._2
+          matches.last.valExpr
         } else {
-          GetStructField(struct, ordinal, maybeName)
+          GetStructField(u.structExpr, ordinal, maybeName)
         }
 
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>

@@ -108,7 +108,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         EliminateSerialization,
         RemoveRedundantAliases,
         RemoveNoopOperators,
-        CombineWithFields,
+        CombineUpdateFields,
         SimplifyExtractValueOps,
         CombineConcats) ++
       extendedOperatorOptimizationRules
@@ -217,8 +217,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
-    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
+    Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
 
     // remove any batches with no rules. this may happen when subclasses do not add optional rules.
     batches.filter(_.rules.nonEmpty)
   }
@@ -251,7 +250,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
-      ReplaceWithFieldsExpression.ruleName :: Nil
+      ReplaceUpdateFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.

@@ -17,26 +17,26 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.expressions.UpdateFields
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Combines all adjacent [[UpdateFields]] expression into a single [[UpdateFields]] expression.
  */
-object CombineWithFields extends Rule[LogicalPlan] {
+object CombineUpdateFields extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
-      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+    case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
+      UpdateFields(struct, fieldOps1 ++ fieldOps2)
   }
 }
 
 /**
- * Replaces [[WithFields]] expression with an evaluable expression.
+ * Replaces [[UpdateFields]] expression with an evaluable expression.
  */
-object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case w: WithFields => w.evalExpr
+    case u: UpdateFields => u.evalExpr
   }
 }

@@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 
-class CombineWithFieldsSuite extends PlanTest {
+class CombineUpdateFieldsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+    val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
   }
 
   private val testRelation = LocalRelation('a.struct('a1.int))
 
-  test("combines two WithFields") {
+  test("combines two adjacent UpdateFields Expressions") {
     val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
+        UpdateFields(
+          UpdateFields(
             'a,
-            Seq("b1"),
-            Seq(Literal(4))),
-          Seq("c1"),
-          Seq(Literal(5))), "out")())
+            WithField("b1", Literal(4)) :: Nil),
+          WithField("c1", Literal(5)) :: Nil), "out")())
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        Nil), "out")())
       .analyze
 
     comparePlans(optimized, correctAnswer)
   }
 
-  test("combines three WithFields") {
+  test("combines three adjacent UpdateFields Expressions") {
     val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
-            WithFields(
+        UpdateFields(
+          UpdateFields(
+            UpdateFields(
              'a,
-              Seq("b1"),
-              Seq(Literal(4))),
-            Seq("c1"),
-            Seq(Literal(5))),
-          Seq("d1"),
-          Seq(Literal(6))), "out")())
+              WithField("b1", Literal(4)) :: Nil),
+            WithField("c1", Literal(5)) :: Nil),
+          WithField("d1", Literal(6)) :: Nil), "out")())
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        WithField("d1", Literal(6)) :: Nil), "out")())
      .analyze
 
     comparePlans(optimized, correctAnswer)

@@ -453,49 +453,72 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
   }
 
-  private val structAttr = 'struct1.struct('a.int)
+  private val structAttr = 'struct1.struct('a.int, 'b.int)
   private val testStructRelation = LocalRelation(structAttr)
 
-  test("simplify GetStructField on WithFields that is not changing the attribute being extracted") {
-    val query = testStructRelation.select(
-      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt")
-    val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt")
-    checkRule(query, expected)
+  test("simplify GetStructField on UpdateFields that is not modifying the attribute being " +
+    "extracted") {
+    // add attribute, extract an attribute from the original struct
+    val query1 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("b", Literal(1)) :: Nil), 0, None) as "outerAtt")
+    // drop attribute, extract an attribute from the original struct
+    val query2 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("b") ::
+      Nil), 0, None) as "outerAtt")
+    // drop attribute, add attribute, extract an attribute from the original struct
+    val query3 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("b") ::
+      WithField("c", Literal(2)) :: Nil), 0, None) as "outerAtt")
+    // drop attribute, add attribute, extract an attribute from the original struct
+    val query4 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("a") ::
+      WithField("a", Literal(1)) :: Nil), 0, None) as "outerAtt")
+
+    val expected = testStructRelation.select(GetStructField('struct1, 0, None) as "outerAtt")
+    Seq(query1, query2, query3, query4).foreach {
+      query => checkRule(query, expected)
+    }
   }
 
-  test("simplify GetStructField on WithFields that is changing the attribute being extracted") {
-    val query = testStructRelation.select(
-      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt")
+  test("simplify GetStructField on UpdateFields that is modifying the attribute being extracted") {
+    // add attribute, and then extract it
+    val query1 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("c", Literal(1)) :: Nil), 2, None) as "outerAtt")
+    // replace attribute, and then extract it
+    val query2 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("b", Literal(1)) :: Nil), 1, None) as "outerAtt")
+    // add attribute, replace the same attribute, and then extract it
+    val query3 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("c", Literal(2)) :: WithField("c", Literal(1)) :: Nil), 2, None) as "outerAtt")
+    // replace the same attribute twice, and then extract it
+    val query4 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("b", Literal(2)) :: WithField("b", Literal(1)) :: Nil), 1, None) as "outerAtt")
+    // replace attribute, drop another attribute, extract the replaced attribute
+    val query5 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("a", Literal(1)) :: DropField("b") :: Nil), 0, None) as "outerAtt")
+    // drop attribute, add attribute with same name, and then extract the added attribute
+    val query6 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("a") ::
+      WithField("a", Literal(1)) :: Nil), 1, None) as "outerAtt")
+
     val expected = testStructRelation.select(Literal(1) as "outerAtt")
-    checkRule(query, expected)
+    Seq(query1, query2, query3, query4, query5, query6).foreach {
+      query => checkRule(query, expected)
+    }
   }
 
-  test(
-    "simplify GetStructField on WithFields that is changing the attribute being extracted twice") {
+  test("simplify multiple GetStructField on the same UpdateFields expression") {
     val query = testStructRelation
-      .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1,
-        Some("b")) as "outerAtt")
-    val expected = testStructRelation.select(Literal(2) as "outerAtt")
-    checkRule(query, expected)
-  }
-
-  test("collapse multiple GetStructField on the same WithFields") {
-    val query = testStructRelation
-      .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2")
+      .select(UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2")
       .select(
         GetStructField('struct2, 0, Some("a")) as "struct1A",
         GetStructField('struct2, 1, Some("b")) as "struct1B")
-    val expected = testStructRelation.select(
-      GetStructField('struct1, 0, Some("a")) as "struct1A",
-      Literal(2) as "struct1B")
+    val expected = testStructRelation
+      .select(GetStructField('struct1, 0, Some("a")) as "struct1A", Literal(2) as "struct1B")
     checkRule(query, expected)
   }
 
-  test("collapse multiple GetStructField on different WithFields") {
+  test("simplify multiple GetStructField on different UpdateFields expressions") {
     val query = testStructRelation
       .select(
-        WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2",
-        WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3")
+        UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2",
+        UpdateFields('struct1, WithField("b", Literal(3)) :: Nil) as "struct3")
       .select(
         GetStructField('struct2, 0, Some("a")) as "struct2A",
         GetStructField('struct2, 1, Some("b")) as "struct2B",
@@ -503,10 +526,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
         GetStructField('struct3, 1, Some("b")) as "struct3B")
     val expected = testStructRelation
       .select(
-        GetStructField('struct1, 0, Some("a")) as "struct2A",
-        Literal(2) as "struct2B",
-        GetStructField('struct1, 0, Some("a")) as "struct3A",
-        Literal(3) as "struct3B")
+        GetStructField('struct1, 0, Some("a")) as "struct2A", Literal(2) as "struct2B",
+        GetStructField('struct1, 0, Some("a")) as "struct3A", Literal(3) as "struct3B")
     checkRule(query, expected)
   }
 }

@@ -906,34 +906,84 @@ class Column(val expr: Expression) extends Logging {
    */
   // scalastyle:on line.size.limit
   def withField(fieldName: String, col: Column): Column = withExpr {
-    require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
+    updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr))
+  }
 
-    val nameParts = if (fieldName.isEmpty) {
-      fieldName :: Nil
-    } else {
-      CatalystSqlParser.parseMultipartIdentifier(fieldName)
-    }
-    withFieldHelper(expr, nameParts, Nil, col.expr)
-  }
-
-  private def withFieldHelper(
-      struct: Expression,
-      namePartsRemaining: Seq[String],
-      namePartsDone: Seq[String],
-      value: Expression) : WithFields = {
-    val name = namePartsRemaining.head
-    if (namePartsRemaining.length == 1) {
-      WithFields(struct, name :: Nil, value :: Nil)
-    } else {
-      val newNamesRemaining = namePartsRemaining.tail
-      val newNamesDone = namePartsDone :+ name
-
-      val newValue = withFieldHelper(
-        struct = UnresolvedExtractValue(struct, Literal(name)),
-        namePartsRemaining = newNamesRemaining,
-        namePartsDone = newNamesDone,
-        value = value)
-      WithFields(struct, name :: Nil, newValue :: Nil)
-    }
-  }
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that drops fields in `StructType` by name.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("c"))
+   *   // result: {"a":1,"b":2}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b", "c"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("a", "b"))
+   *   // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct
+   *
+   *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: null of type struct<a:int>
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   *   df.select($"struct_col".dropFields("a.b"))
+   *   // result: {"a":{"a":1}}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+   *   df.select($"struct_col".dropFields("a.c"))
+   *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def dropFields(fieldNames: String*): Column = withExpr {
+    def dropField(expr: Expression, fieldName: String): UpdateFields =
+      updateFieldsHelper(expr, nameParts(fieldName), name => DropField(name))
+
+    fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
+      (resExpr, fieldName) => dropField(resExpr, fieldName)
+    }
+  }
+
+  private def nameParts(fieldName: String): Seq[String] = {
+    require(fieldName != null, "fieldName cannot be null")
+
+    if (fieldName.isEmpty) {
+      fieldName :: Nil
+    } else {
+      CatalystSqlParser.parseMultipartIdentifier(fieldName)
+    }
+  }
+
+  private def updateFieldsHelper(
+      struct: Expression,
+      namePartsRemaining: Seq[String],
+      valueFunc: String => StructFieldsOperation): UpdateFields = {
+    val fieldName = namePartsRemaining.head
+    if (namePartsRemaining.length == 1) {
+      UpdateFields(struct, valueFunc(fieldName) :: Nil)
+    } else {
+      val newValue = updateFieldsHelper(
+        struct = UnresolvedExtractValue(struct, Literal(fieldName)),
+        namePartsRemaining = namePartsRemaining.tail,
+        valueFunc = valueFunc)
+      UpdateFields(struct, WithField(fieldName, newValue) :: Nil)
+    }
+  }

@@ -984,7 +984,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     intercept[IllegalArgumentException] {
       structLevel1.withColumn("a", $"a".withField(null, null))
-    }.getMessage should include("fieldName cannot be null")
+    }.getMessage should include("col cannot be null")
   }
 
   test("withField should throw an exception if any intermediate structs don't exist") {

@@ -1452,4 +1452,353 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
         .select($"struct_col".withField("a.c", lit(3)))
     }.getMessage should include("Ambiguous reference to fields")
   }
+
+  test("dropFields should throw an exception if called on a non-StructType column") {
+    intercept[AnalysisException] {
+      testData.withColumn("key", $"key".dropFields("a"))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("dropFields should throw an exception if fieldName argument is null") {
+    intercept[IllegalArgumentException] {
+      structLevel1.withColumn("a", $"a".dropFields(null))
+    }.getMessage should include("fieldName cannot be null")
+  }
+
+  test("dropFields should throw an exception if any intermediate structs don't exist") {
+    intercept[AnalysisException] {
+      structLevel2.withColumn("a", 'a.dropFields("x.b"))
+    }.getMessage should include("No such struct field x in a")
+
+    intercept[AnalysisException] {
+      structLevel3.withColumn("a", 'a.dropFields("a.x.b"))
+    }.getMessage should include("No such struct field x in a")
+  }
+
+  test("dropFields should throw an exception if intermediate field is not a struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("b.a"))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("dropFields should throw an exception if intermediate field reference is ambiguous") {
+    intercept[AnalysisException] {
+      val structLevel2: DataFrame = spark.createDataFrame(
+        sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", structType, nullable = false),
+            StructField("a", structType, nullable = false))),
+            nullable = false))))
+
+      structLevel2.withColumn("a", 'a.dropFields("a.b"))
+    }.getMessage should include("Ambiguous reference to fields")
+  }
+
+  test("dropFields should drop field in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.dropFields("b")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel1.withColumn("a", $"a".dropFields("b")),
+      Row(null) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("dropFields should drop multiple fields in struct") {
+    Seq(
+      structLevel1.withColumn("a", $"a".dropFields("b", "c")),
+      structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c"))
+    ).foreach { df =>
+      checkAnswerAndSchema(
+        df,
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should throw an exception if no fields will be left in struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("a", "b", "c"))
+    }.getMessage should include("cannot drop all fields in struct")
+  }
+
+  test("dropFields should drop field in nested struct") {
+    checkAnswerAndSchema(
+      structLevel2.withColumn("a", 'a.dropFields("a.b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields in nested struct") {
+    checkAnswerAndSchema(
+      structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")),
+      Row(Row(Row(1))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".dropFields("a.b")),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields in nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in deeply nested struct") {
+    checkAnswerAndSchema(
+      structLevel3.withColumn("a", 'a.dropFields("a.a.b")),
+      Row(Row(Row(Row(1, 3)))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop all fields with given name in struct") {
+    val structLevel1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.dropFields("b")),
+      Row(Row(1)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in struct even if casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should not drop field in struct because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+        Row(Row(1, 1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswerAndSchema(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+        Row(Row(1, 1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should drop nested field in struct even if casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswerAndSchema(
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")),
+        Row(Row(Row(1), Row(1, 1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("A", StructType(Seq(
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("B", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+
+      checkAnswerAndSchema(
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")),
+        Row(Row(Row(1, 1), Row(1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("b", StructType(Seq(
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should throw an exception because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a"))
+      }.getMessage should include("No such struct field A in a, B")
+
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a"))
+      }.getMessage should include("No such struct field b in a, B")
+    }
+  }
+
+  test("dropFields should drop only fields that exist") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.dropFields("d")),
+      Row(Row(1, null, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.dropFields("b", "d")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      structLevel2.withColumn("a", $"a".dropFields("a.b", "a.d")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields at arbitrary levels of nesting in a single call") {
+    val df: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", structType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      df.withColumn("a", $"a".dropFields("a.b", "b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))), nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields user-facing examples") {
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(Row(1)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("c")),
+      Row(Row(1, 2)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+        .select($"struct_col".dropFields("b", "c")),
+      Row(Row(1)))
+
+    intercept[AnalysisException] {
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("a", "b"))
+    }.getMessage should include("cannot drop all fields in struct")
+
+    checkAnswer(
+      sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(null))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(Row(1)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+        .select($"struct_col".dropFields("a.b")),
+      Row(Row(Row(1))))
+
+    intercept[AnalysisException] {
+      sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+        .select($"struct_col".dropFields("a.c"))
+    }.getMessage should include("Ambiguous reference to fields")
+  }
 }