[SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas

## What changes were proposed in this pull request?

A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because of duplicate attributes. This happens because we are not dealing with this specific case in our `dedupAttr` rules.

The PR fix the issue by adding the management of the specific case

## How was this patch tested?

added UT + manual tests

Author: Marco Gaido <marcogaido91@gmail.com>
Author: Marco Gaido <mgaido@hortonworks.com>

Closes #21737 from mgaido91/SPARK-24208.
This commit is contained in:
Marco Gaido 2018-07-11 09:29:19 -07:00 committed by Xiao Li
parent 592cc84583
commit ebf4bfb966
3 changed files with 32 additions and 0 deletions

View file

@ -5925,6 +5925,22 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
def test_self_join_with_pandas(self):
import pyspark.sql.functions as F
@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
def dummy_pandas_udf(df):
return df[['key', 'col']]
df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
Row(key=2, col='C')])
dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
# this was throwing an AnalysisException before SPARK-24208
res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
F.col('temp0.key') == F.col('temp1.key'))
self.assertEquals(res.count(), 5)
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,

View file

@ -738,6 +738,10 @@ class Analyzer(
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(output = output.map(_.newInstance())))
case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())

View file

@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
}
datasetWithUDF.unpersist(true)
}
test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
"pyUDF",
null,
StructType(Seq(StructField("s", LongType))),
Seq.empty,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
true))
val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s")
df1.queryExecution.assertAnalyzed()
}
}