[SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas
## What changes were proposed in this pull request? A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because of duplicate attributes. This happens because we are not dealing with this specific case in our `dedupAttr` rules. The PR fix the issue by adding the management of the specific case ## How was this patch tested? added UT + manual tests Author: Marco Gaido <marcogaido91@gmail.com> Author: Marco Gaido <mgaido@hortonworks.com> Closes #21737 from mgaido91/SPARK-24208.
This commit is contained in:
parent
592cc84583
commit
ebf4bfb966
|
@ -5925,6 +5925,22 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
|
|||
'mixture.*aggregate function.*group aggregate pandas UDF'):
|
||||
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
|
||||
|
||||
def test_self_join_with_pandas(self):
|
||||
import pyspark.sql.functions as F
|
||||
|
||||
@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
|
||||
def dummy_pandas_udf(df):
|
||||
return df[['key', 'col']]
|
||||
|
||||
df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
|
||||
Row(key=2, col='C')])
|
||||
dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
|
||||
|
||||
# this was throwing an AnalysisException before SPARK-24208
|
||||
res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
|
||||
F.col('temp0.key') == F.col('temp1.key'))
|
||||
self.assertEquals(res.count(), 5)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not _have_pandas or not _have_pyarrow,
|
||||
|
|
|
@ -738,6 +738,10 @@ class Analyzer(
|
|||
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
|
||||
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
|
||||
|
||||
case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
|
||||
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
|
||||
(oldVersion, oldVersion.copy(output = output.map(_.newInstance())))
|
||||
|
||||
case oldVersion: Generate
|
||||
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
|
||||
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
|
||||
|
|
|
@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
datasetWithUDF.unpersist(true)
|
||||
}
|
||||
|
||||
test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
|
||||
val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
|
||||
"pyUDF",
|
||||
null,
|
||||
StructType(Seq(StructField("s", LongType))),
|
||||
Seq.empty,
|
||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||
true))
|
||||
val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s")
|
||||
df1.queryExecution.assertAnalyzed()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue