[SPARK-35876][SQL] ArraysZip should retain field names to avoid being re-written by analyzer/optimizer
### What changes were proposed in this pull request? This PR fixes an issue where field names of structs generated by `arrays_zip` function could be unexpectedly re-written by analyzer/optimizer. Here is an example. ``` val df = sc.parallelize(Seq((Array(1, 2), Array(3, 4)))).toDF("a1", "b1").selectExpr("arrays_zip(a1, b1) as zipped") df.printSchema root |-- zipped: array (nullable = true) | |-- element: struct (containsNull = false) | | |-- a1: integer (nullable = true) // OK. a1 is expected name | | |-- b1: integer (nullable = true) // OK. b1 is expected name df.explain == Physical Plan == *(1) Project [arrays_zip(_1#3, _2#4) AS zipped#12] // Not OK. field names are re-written as _1 and _2 respectively df.write.parquet("/tmp/test.parquet") val df2 = spark.read.parquet("/tmp/test.parquet") df2.printSchema root |-- zipped: array (nullable = true) | |-- element: struct (containsNull = true) | | |-- _1: integer (nullable = true) // Not OK. a1 is expected but got _1 | | |-- _2: integer (nullable = true) // Not OK. b1 is expected but got _2 ``` This issue happens when aliases are eliminated by `AliasHelper.replaceAliasButKeepName` or `AliasHelper.trimNonTopLevelAliases` called via analyzer/optimizer: b89cd8d75a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala (L883)
b89cd8d75a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (L3759)
I investigated functions which can be affected by this issue but I found only `arrays_zip` so far. To fix this issue, this PR changes the definition of `ArraysZip` to retain field names to avoid being re-written by analyzer/optimizer. ### Why are the changes needed? This is apparently a bug. ### Does this PR introduce _any_ user-facing change? No. After this change, the field names are no longer re-written, but this is the expected behavior for users. ### How was this patch tested? New tests. Closes #33106 from sarutak/arrays-zip-retain-names. Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
620fde4767
commit
880bbd6aaa
|
@ -261,7 +261,7 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
AddMetadataColumns ::
|
||||
DeduplicateRelations ::
|
||||
ResolveReferences ::
|
||||
ResolveCreateNamedStruct ::
|
||||
ResolveExpressionsWithNamePlaceholders ::
|
||||
ResolveDeserializer ::
|
||||
ResolveNewInstance ::
|
||||
ResolveUpCast ::
|
||||
|
@ -3881,11 +3881,19 @@ object TimeWindowing extends Rule[LogicalPlan] {
|
|||
}
|
||||
|
||||
/**
|
||||
* Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
|
||||
* Resolve expressions if they contains [[NamePlaceholder]]s.
|
||||
*/
|
||||
object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
|
||||
object ResolveExpressionsWithNamePlaceholders extends Rule[LogicalPlan] {
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
|
||||
_.containsPattern(CREATE_NAMED_STRUCT), ruleId) {
|
||||
_.containsAnyPattern(ARRAYS_ZIP, CREATE_NAMED_STRUCT), ruleId) {
|
||||
case e: ArraysZip if !e.resolved =>
|
||||
val names = e.children.zip(e.names).map {
|
||||
case (e: NamedExpression, NamePlaceholder) if e.resolved =>
|
||||
Literal(e.name)
|
||||
case (_, other) => other
|
||||
}
|
||||
ArraysZip(e.children, names)
|
||||
|
||||
case e: CreateNamedStruct if !e.resolved =>
|
||||
val children = e.children.grouped(2).flatMap {
|
||||
case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
|
||||
|
|
|
@ -23,11 +23,11 @@ import scala.collection.mutable
|
|||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedSeed}
|
||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
|
||||
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern}
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
|
||||
|
@ -181,16 +181,35 @@ case class MapKeys(child: Expression)
|
|||
""",
|
||||
group = "array_funcs",
|
||||
since = "2.4.0")
|
||||
case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
|
||||
case class ArraysZip(children: Seq[Expression], names: Seq[Expression])
|
||||
extends Expression with ExpectsInputTypes {
|
||||
|
||||
def this(children: Seq[Expression]) = {
|
||||
this(
|
||||
children,
|
||||
children.zipWithIndex.map {
|
||||
case (u: UnresolvedAttribute, _) => Literal(u.nameParts.last)
|
||||
case (e: NamedExpression, _) if e.resolved => Literal(e.name)
|
||||
case (e: NamedExpression, _) => NamePlaceholder
|
||||
case (_, idx) => Literal(idx.toString)
|
||||
})
|
||||
}
|
||||
|
||||
if (children.size != names.size) {
|
||||
throw new IllegalArgumentException(
|
||||
"The numbers of zipped arrays and field names should be the same")
|
||||
}
|
||||
|
||||
final override val nodePatterns: Seq[TreePattern] = Seq(ARRAYS_ZIP)
|
||||
|
||||
override lazy val resolved: Boolean =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && names.forall(_.resolved)
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
|
||||
|
||||
@transient override lazy val dataType: DataType = {
|
||||
val fields = children.zip(arrayElementTypes).zipWithIndex.map {
|
||||
case ((expr: NamedExpression, elementType), _) =>
|
||||
StructField(expr.name, elementType, nullable = true)
|
||||
case ((_, elementType), idx) =>
|
||||
StructField(idx.toString, elementType, nullable = true)
|
||||
val fields = arrayElementTypes.zip(names).map {
|
||||
case (elementType, Literal(name, StringType)) =>
|
||||
StructField(name.toString, elementType, nullable = true)
|
||||
}
|
||||
ArrayType(StructType(fields), containsNull = false)
|
||||
}
|
||||
|
@ -332,6 +351,12 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
|
|||
copy(children = newChildren)
|
||||
}
|
||||
|
||||
/**
 * Companion keeping the legacy single-argument factory, so callers that
 * construct an [[ArraysZip]] from its children alone continue to compile;
 * the field names are filled in by the auxiliary constructor.
 */
object ArraysZip {
  def apply(children: Seq[Expression]): ArraysZip = new ArraysZip(children)
}
|
||||
|
||||
/**
|
||||
* Returns an unordered array containing the values of the map.
|
||||
*/
|
||||
|
|
|
@ -80,7 +80,7 @@ object RuleIdCollection {
|
|||
"org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" ::
|
||||
"org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases" ::
|
||||
"org.apache.spark.sql.catalyst.analysis.EliminateUnions" ::
|
||||
"org.apache.spark.sql.catalyst.analysis.ResolveCreateNamedStruct" ::
|
||||
"org.apache.spark.sql.catalyst.analysis.ResolveExpressionsWithNamePlaceholders" ::
|
||||
"org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveCoalesceHints" ::
|
||||
"org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" ::
|
||||
"org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" ::
|
||||
|
|
|
@ -26,6 +26,7 @@ object TreePattern extends Enumeration {
|
|||
val AGGREGATE_EXPRESSION = Value(0)
|
||||
val ALIAS: Value = Value
|
||||
val AND_OR: Value = Value
|
||||
val ARRAYS_ZIP: Value = Value
|
||||
val ATTRIBUTE_REFERENCE: Value = Value
|
||||
val APPEND_COLUMNS: Value = Value
|
||||
val AVERAGE: Value = Value
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.io.File
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.sql.{Date, Timestamp}
|
||||
|
||||
|
@ -24,7 +25,8 @@ import scala.util.Random
|
|||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, ArraysZip, AttributeReference, Expression, NamedExpression, UnaryExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
||||
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
|
||||
|
@ -552,6 +554,74 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
Row(Seq(Row(0, 1, 2, 3, 4, 5))))
|
||||
}
|
||||
|
||||
test("SPARK-35876: arrays_zip should retain field names") {
  withTempDir { dir =>
    val df = spark.sparkContext.parallelize(
      Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6)))).toDF("val1", "val2")
    val qualifiedDF = df.as("foo")

    // Returns the ArraysZip expression sitting under the top-level Alias of
    // `zipped`'s logical plan, asserting the expected expression shape.
    def arraysZipOf(zipped: DataFrame): Expression = {
      val maybeAlias = zipped.queryExecution.logical.expressions.head
      assert(maybeAlias.isInstanceOf[Alias])
      val maybeArraysZip = maybeAlias.children.head
      assert(maybeArraysZip.isInstanceOf[ArraysZip])
      maybeArraysZip
    }

    // Round-trips `zipped` through Parquet and checks that the struct fields
    // inside the zipped array element kept the expected names.
    def checkFieldNames(zipped: DataFrame, dirName: String, expected: Seq[String]): Unit = {
      val file = new File(dir, dirName)
      zipped.write.parquet(file.getAbsolutePath)
      val restored = spark.read.parquet(file.getAbsolutePath)
      val fieldNames = restored.schema.head.dataType.asInstanceOf[ArrayType]
        .elementType.asInstanceOf[StructType].fieldNames
      assert(fieldNames.toSeq === expected)
    }

    // Fields are UnresolvedAttribute
    val zippedDF1 = qualifiedDF.select(arrays_zip($"foo.val1", $"foo.val2") as "zipped")
    assert(arraysZipOf(zippedDF1).children.forall(_.isInstanceOf[UnresolvedAttribute]))
    checkFieldNames(zippedDF1, "arrays_zip1", Seq("val1", "val2"))

    // Fields are resolved NamedExpression
    val zippedDF2 = df.select(arrays_zip(df("val1"), df("val2")) as "zipped")
    assert(arraysZipOf(zippedDF2).children.forall(
      e => e.isInstanceOf[AttributeReference] && e.resolved))
    checkFieldNames(zippedDF2, "arrays_zip2", Seq("val1", "val2"))

    // Fields are unresolved NamedExpression
    val zippedDF3 = df.select(arrays_zip($"val1" as "val3", $"val2" as "val4") as "zipped")
    assert(arraysZipOf(zippedDF3).children.forall(e => e.isInstanceOf[Alias] && !e.resolved))
    checkFieldNames(zippedDF3, "arrays_zip3", Seq("val3", "val4"))

    // Fields are neither UnresolvedAttribute nor NamedExpression, so the
    // struct falls back to positional field names "0", "1".
    val zippedDF4 = df.select(arrays_zip(array_sort($"val1"), array_sort($"val2")) as "zipped")
    assert(arraysZipOf(zippedDF4).children.forall {
      case _: UnresolvedAttribute | _: NamedExpression => false
      case _ => true
    })
    checkFieldNames(zippedDF4, "arrays_zip4", Seq("0", "1"))
  }
}
|
||||
|
||||
def testSizeOfMap(sizeOfNull: Any): Unit = {
|
||||
val df = Seq(
|
||||
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),
|
||||
|
|
Loading…
Reference in a new issue