[SPARK-35876][SQL] ArraysZip should retain field names to avoid being re-written by analyzer/optimizer

### What changes were proposed in this pull request?

This PR fixes an issue where the field names of structs generated by the `arrays_zip` function can be unexpectedly re-written by the analyzer/optimizer.
Here is an example.
```
val df = sc.parallelize(Seq((Array(1, 2), Array(3, 4)))).toDF("a1", "b1").selectExpr("arrays_zip(a1, b1) as zipped")
df.printSchema
root
 |-- zipped: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- a1: integer (nullable = true)                                      // OK. a1 is expected name
 |    |    |-- b1: integer (nullable = true)                                      // OK. b1 is expected name

df.explain
== Physical Plan ==
*(1) Project [arrays_zip(_1#3, _2#4) AS zipped#12]               // Not OK. field names are re-written as _1 and _2 respectively

df.write.parquet("/tmp/test.parquet")
val df2 = spark.read.parquet("/tmp/test.parquet")

df2.printSchema
root
 |-- zipped: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: integer (nullable = true)                                      // Not OK. a1 is expected but got _1
 |    |    |-- _2: integer (nullable = true)                                      // Not OK. b1 is expected but got _2
```

This issue happens when the aliases are eliminated by `AliasHelper.replaceAliasButKeepName` or `AliasHelper.trimNonTopLevelAliases`, which are called from the analyzer/optimizer:
b89cd8d75a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala (L883)
b89cd8d75a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (L3759)
I investigated which functions can be affected by this issue, but so far I have found only `arrays_zip`.
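As a minimal, self-contained sketch of the mechanism (toy classes, not the real Catalyst ones): the pre-fix `ArraysZip` re-derived its struct field names from whatever its children currently were, so once alias trimming replaced `a1` (an `Alias` of `_1`) with the bare `_1` attribute, the derived name changed with it.
```
// Toy model, not Spark code: field names are re-derived from the current children.
sealed trait Expr { def name: String }
case class AttrRef(override val name: String) extends Expr             // e.g. the internal "_1"
case class Alias(child: Expr, override val name: String) extends Expr  // e.g. "_1 AS a1"

def fieldNames(children: Seq[Expr]): Seq[String] = children.map(_.name)

val original = Seq(Alias(AttrRef("_1"), "a1"), Alias(AttrRef("_2"), "b1"))
// A trimNonTopLevelAliases-like step strips the aliases but keeps their children.
val trimmed = original.map {
  case Alias(child, _) => child
  case e => e
}

println(fieldNames(original)) // List(a1, b1) -- what the user wrote
println(fieldNames(trimmed))  // List(_1, _2) -- what gets executed and written
```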

To fix this issue, this PR changes the definition of `ArraysZip` so that it retains the field names itself, preventing them from being re-written by the analyzer/optimizer.

### Why are the changes needed?

This is clearly a bug.

### Does this PR introduce _any_ user-facing change?

No. After this change, the field names are no longer re-written, which is the behavior users should expect.
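For illustration, re-running the repro above with this fix should now preserve the names across a Parquet round trip (expected output sketched from what the new tests assert; the path is arbitrary):
```
val df = sc.parallelize(Seq((Array(1, 2), Array(3, 4)))).toDF("a1", "b1").selectExpr("arrays_zip(a1, b1) as zipped")
df.write.parquet("/tmp/test.parquet")
spark.read.parquet("/tmp/test.parquet").printSchema
root
 |-- zipped: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a1: integer (nullable = true)                                      // OK. a1 is retained
 |    |    |-- b1: integer (nullable = true)                                      // OK. b1 is retained
```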

### How was this patch tested?

New tests.

Closes #33106 from sarutak/arrays-zip-retain-names.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Authored by Kousuke Saruta on 2021-06-29 12:28:41 +09:00; committed by Hyukjin Kwon
parent 620fde4767
commit 880bbd6aaa
5 changed files with 118 additions and 14 deletions


```
@@ -261,7 +261,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       AddMetadataColumns ::
       DeduplicateRelations ::
       ResolveReferences ::
-      ResolveCreateNamedStruct ::
+      ResolveExpressionsWithNamePlaceholders ::
       ResolveDeserializer ::
       ResolveNewInstance ::
       ResolveUpCast ::
@@ -3881,11 +3881,19 @@ object TimeWindowing extends Rule[LogicalPlan] {
 }
 
 /**
- * Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
+ * Resolve expressions if they contains [[NamePlaceholder]]s.
  */
-object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
+object ResolveExpressionsWithNamePlaceholders extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
-    _.containsPattern(CREATE_NAMED_STRUCT), ruleId) {
+    _.containsAnyPattern(ARRAYS_ZIP, CREATE_NAMED_STRUCT), ruleId) {
+    case e: ArraysZip if !e.resolved =>
+      val names = e.children.zip(e.names).map {
+        case (e: NamedExpression, NamePlaceholder) if e.resolved =>
+          Literal(e.name)
+        case (_, other) => other
+      }
+      ArraysZip(e.children, names)
+
     case e: CreateNamedStruct if !e.resolved =>
       val children = e.children.grouped(2).flatMap {
         case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
```
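For an aliased argument, the lifecycle of a name placeholder under this rule looks like the following (illustrative, using hypothetical column names):
```
// arrays_zip($"v1" as "v3", ...) before the alias is resolved:
//   ArraysZip(children = Seq($"v1" AS v3, ...), names = Seq(NamePlaceholder, ...))
// after ResolveExpressionsWithNamePlaceholders runs on the resolved alias:
//   ArraysZip(children = Seq($"v1" AS v3, ...), names = Seq(Literal("v3"), ...))
```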


```
@@ -23,11 +23,11 @@ import scala.collection.mutable
 import scala.reflect.ClassTag
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedSeed}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -181,16 +181,35 @@ case class MapKeys(child: Expression)
   """,
   group = "array_funcs",
   since = "2.4.0")
-case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
+case class ArraysZip(children: Seq[Expression], names: Seq[Expression])
+  extends Expression with ExpectsInputTypes {
+
+  def this(children: Seq[Expression]) = {
+    this(
+      children,
+      children.zipWithIndex.map {
+        case (u: UnresolvedAttribute, _) => Literal(u.nameParts.last)
+        case (e: NamedExpression, _) if e.resolved => Literal(e.name)
+        case (e: NamedExpression, _) => NamePlaceholder
+        case (_, idx) => Literal(idx.toString)
+      })
+  }
+
+  if (children.size != names.size) {
+    throw new IllegalArgumentException(
+      "The numbers of zipped arrays and field names should be the same")
+  }
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(ARRAYS_ZIP)
+  override lazy val resolved: Boolean =
+    childrenResolved && checkInputDataTypes().isSuccess && names.forall(_.resolved)
 
   override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
 
   @transient override lazy val dataType: DataType = {
-    val fields = children.zip(arrayElementTypes).zipWithIndex.map {
-      case ((expr: NamedExpression, elementType), _) =>
-        StructField(expr.name, elementType, nullable = true)
-      case ((_, elementType), idx) =>
-        StructField(idx.toString, elementType, nullable = true)
+    val fields = arrayElementTypes.zip(names).map {
+      case (elementType, Literal(name, StringType)) =>
+        StructField(name.toString, elementType, nullable = true)
     }
     ArrayType(StructType(fields), containsNull = false)
   }
@@ -332,6 +351,12 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
     copy(children = newChildren)
 }
 
+object ArraysZip {
+  def apply(children: Seq[Expression]): ArraysZip = {
+    new ArraysZip(children)
+  }
+}
+
 /**
  * Returns an unordered array containing the values of the map.
  */
```
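As a quick reference, the new auxiliary constructor derives one field name per child as follows (illustrative; these mirror the four cases exercised by the new test below, where `df` has array columns `val1` and `val2`):
```
arrays_zip($"foo.val1", $"foo.val2")                  // UnresolvedAttribute       -> "val1", "val2"
arrays_zip(df("val1"), df("val2"))                    // resolved NamedExpression  -> "val1", "val2"
arrays_zip($"val1" as "val3", $"val2" as "val4")      // unresolved Alias          -> NamePlaceholder, later "val3", "val4"
arrays_zip(array_sort($"val1"), array_sort($"val2"))  // anything else             -> positional "0", "1"
```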


```
@@ -80,7 +80,7 @@ object RuleIdCollection {
     "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" ::
     "org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases" ::
     "org.apache.spark.sql.catalyst.analysis.EliminateUnions" ::
-    "org.apache.spark.sql.catalyst.analysis.ResolveCreateNamedStruct" ::
+    "org.apache.spark.sql.catalyst.analysis.ResolveExpressionsWithNamePlaceholders" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveCoalesceHints" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" ::
```


```
@@ -26,6 +26,7 @@ object TreePattern extends Enumeration {
   val AGGREGATE_EXPRESSION = Value(0)
   val ALIAS: Value = Value
   val AND_OR: Value = Value
+  val ARRAYS_ZIP: Value = Value
   val ATTRIBUTE_REFERENCE: Value = Value
   val APPEND_COLUMNS: Value = Value
   val AVERAGE: Value = Value
```


```
@@ -17,6 +17,7 @@
 package org.apache.spark.sql
 
+import java.io.File
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
@@ -24,7 +25,8 @@ import scala.util.Random
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@@ -552,6 +554,74 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(Seq(Row(0, 1, 2, 3, 4, 5))))
   }
 
+  test("SPARK-35876: arrays_zip should retain field names") {
+    withTempDir { dir =>
+      val df = spark.sparkContext.parallelize(
+        Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6)))).toDF("val1", "val2")
+      val qualifiedDF = df.as("foo")
+
+      // Fields are UnresolvedAttribute
+      val zippedDF1 = qualifiedDF.select(arrays_zip($"foo.val1", $"foo.val2") as "zipped")
+      val maybeAlias1 = zippedDF1.queryExecution.logical.expressions.head
+      assert(maybeAlias1.isInstanceOf[Alias])
+      val maybeArraysZip1 = maybeAlias1.children.head
+      assert(maybeArraysZip1.isInstanceOf[ArraysZip])
+      assert(maybeArraysZip1.children.forall(_.isInstanceOf[UnresolvedAttribute]))
+      val file1 = new File(dir, "arrays_zip1")
+      zippedDF1.write.parquet(file1.getAbsolutePath)
+      val restoredDF1 = spark.read.parquet(file1.getAbsolutePath)
+      val fieldNames1 = restoredDF1.schema.head.dataType.asInstanceOf[ArrayType]
+        .elementType.asInstanceOf[StructType].fieldNames
+      assert(fieldNames1.toSeq === Seq("val1", "val2"))
+
+      // Fields are resolved NamedExpression
+      val zippedDF2 = df.select(arrays_zip(df("val1"), df("val2")) as "zipped")
+      val maybeAlias2 = zippedDF2.queryExecution.logical.expressions.head
+      assert(maybeAlias2.isInstanceOf[Alias])
+      val maybeArraysZip2 = maybeAlias2.children.head
+      assert(maybeArraysZip2.isInstanceOf[ArraysZip])
+      assert(maybeArraysZip2.children.forall(
+        e => e.isInstanceOf[AttributeReference] && e.resolved))
+      val file2 = new File(dir, "arrays_zip2")
+      zippedDF2.write.parquet(file2.getAbsolutePath)
+      val restoredDF2 = spark.read.parquet(file2.getAbsolutePath)
+      val fieldNames2 = restoredDF2.schema.head.dataType.asInstanceOf[ArrayType]
+        .elementType.asInstanceOf[StructType].fieldNames
+      assert(fieldNames2.toSeq === Seq("val1", "val2"))
+
+      // Fields are unresolved NamedExpression
+      val zippedDF3 = df.select(arrays_zip($"val1" as "val3", $"val2" as "val4") as "zipped")
+      val maybeAlias3 = zippedDF3.queryExecution.logical.expressions.head
+      assert(maybeAlias3.isInstanceOf[Alias])
+      val maybeArraysZip3 = maybeAlias3.children.head
+      assert(maybeArraysZip3.isInstanceOf[ArraysZip])
+      assert(maybeArraysZip3.children.forall(e => e.isInstanceOf[Alias] && !e.resolved))
+      val file3 = new File(dir, "arrays_zip3")
+      zippedDF3.write.parquet(file3.getAbsolutePath)
+      val restoredDF3 = spark.read.parquet(file3.getAbsolutePath)
+      val fieldNames3 = restoredDF3.schema.head.dataType.asInstanceOf[ArrayType]
+        .elementType.asInstanceOf[StructType].fieldNames
+      assert(fieldNames3.toSeq === Seq("val3", "val4"))
+
+      // Fields are neither UnresolvedAttribute nor NamedExpression
+      val zippedDF4 = df.select(arrays_zip(array_sort($"val1"), array_sort($"val2")) as "zipped")
+      val maybeAlias4 = zippedDF4.queryExecution.logical.expressions.head
+      assert(maybeAlias4.isInstanceOf[Alias])
+      val maybeArraysZip4 = maybeAlias4.children.head
+      assert(maybeArraysZip4.isInstanceOf[ArraysZip])
+      assert(maybeArraysZip4.children.forall {
+        case _: UnresolvedAttribute | _: NamedExpression => false
+        case _ => true
+      })
+      val file4 = new File(dir, "arrays_zip4")
+      zippedDF4.write.parquet(file4.getAbsolutePath)
+      val restoredDF4 = spark.read.parquet(file4.getAbsolutePath)
+      val fieldNames4 = restoredDF4.schema.head.dataType.asInstanceOf[ArrayType]
+        .elementType.asInstanceOf[StructType].fieldNames
+      assert(fieldNames4.toSeq === Seq("0", "1"))
+    }
+  }
+
   def testSizeOfMap(sizeOfNull: Any): Unit = {
     val df = Seq(
       (Map[Int, Int](1 -> 1, 2 -> 2), "x"),
```