[SPARK-22239][SQL][PYTHON] Enable grouped aggregate pandas UDFs as window functions with unbounded window frames
## What changes were proposed in this pull request? This PR enables using grouped aggregate pandas UDFs as window functions. The semantics are the same as using a SQL aggregation function as a window function. ``` >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql import Window >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) ... def mean_udf(v): ... return v.mean() >>> w = Window.partitionBy('id') >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() +---+----+------+ | id| v|mean_v| +---+----+------+ | 1| 1.0| 1.5| | 1| 2.0| 1.5| | 2| 3.0| 6.0| | 2| 5.0| 6.0| | 2|10.0| 6.0| +---+----+------+ ``` The scope of this PR is somewhat limited in terms of: (1) Only supports unbounded window, which acts essentially as group by. (2) Only supports aggregation functions, not "transform" like window functions (n -> n mapping). Both of these are left as future work. Especially, (1) needs careful thinking w.r.t. how to pass rolling window data to python efficiently. (2) is a bit easier but does require more changes; therefore I think it's better to leave it as a separate PR. ## How was this patch tested? WindowPandasUDFTests Author: Li Jin <ice.xelloss@gmail.com> Closes #21082 from icexelloss/SPARK-22239-window-udf.
This commit is contained in:
parent
f53818d35b
commit
9786ce66c5
|
@ -40,6 +40,7 @@ private[spark] object PythonEvalType {
|
|||
val SQL_SCALAR_PANDAS_UDF = 200
|
||||
val SQL_GROUPED_MAP_PANDAS_UDF = 201
|
||||
val SQL_GROUPED_AGG_PANDAS_UDF = 202
|
||||
val SQL_WINDOW_AGG_PANDAS_UDF = 203
|
||||
|
||||
def toString(pythonEvalType: Int): String = pythonEvalType match {
|
||||
case NON_UDF => "NON_UDF"
|
||||
|
@ -47,6 +48,7 @@ private[spark] object PythonEvalType {
|
|||
case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
|
||||
case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
|
||||
case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
|
||||
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -74,6 +74,7 @@ class PythonEvalType(object):
|
|||
SQL_SCALAR_PANDAS_UDF = 200
|
||||
SQL_GROUPED_MAP_PANDAS_UDF = 201
|
||||
SQL_GROUPED_AGG_PANDAS_UDF = 202
|
||||
SQL_WINDOW_AGG_PANDAS_UDF = 203
|
||||
|
||||
|
||||
def portable_hash(x):
|
||||
|
|
|
@ -2616,10 +2616,12 @@ def pandas_udf(f=None, returnType=None, functionType=None):
|
|||
The returned scalar can be either a python primitive type, e.g., `int` or `float`
|
||||
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
|
||||
|
||||
:class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
|
||||
output types.
|
||||
:class:`MapType` and :class:`StructType` are currently not supported as output types.
|
||||
|
||||
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
|
||||
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
|
||||
:class:`pyspark.sql.Window`
|
||||
|
||||
This example shows using grouped aggregated UDFs with groupby:
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df = spark.createDataFrame(
|
||||
|
@ -2636,7 +2638,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
|
|||
| 2| 6.0|
|
||||
+---+-----------+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.GroupedData.agg`
|
||||
This example shows using grouped aggregated UDFs as window functions. Note that only
|
||||
unbounded window frame is supported at the moment:
|
||||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> from pyspark.sql import Window
|
||||
>>> df = spark.createDataFrame(
|
||||
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
|
||||
... ("id", "v"))
|
||||
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
|
||||
... def mean_udf(v):
|
||||
... return v.mean()
|
||||
>>> w = Window.partitionBy('id') \\
|
||||
... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
|
||||
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
|
||||
+---+----+------+
|
||||
| id| v|mean_v|
|
||||
+---+----+------+
|
||||
| 1| 1.0| 1.5|
|
||||
| 1| 2.0| 1.5|
|
||||
| 2| 3.0| 6.0|
|
||||
| 2| 5.0| 6.0|
|
||||
| 2|10.0| 6.0|
|
||||
+---+----+------+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
|
||||
|
||||
.. note:: The user-defined functions are considered deterministic by default. Due to
|
||||
optimization, duplicate invocations may be eliminated or the function may even be invoked
|
||||
|
|
|
@ -5454,6 +5454,15 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
|
|||
expected1 = df.groupby(df.id).agg(sum(df.v))
|
||||
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
||||
|
||||
def test_array_type(self):
    """Grouped-agg pandas UDF should support array<double> as an output type."""
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    df = self.data

    array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
    result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
    # assertEqual: assertEquals is a deprecated unittest alias and raises a
    # DeprecationWarning on modern Python.
    self.assertEqual(result1.first()['v2'], [1.0, 2.0])
|
||||
|
||||
def test_invalid_args(self):
|
||||
from pyspark.sql.functions import mean
|
||||
|
||||
|
@ -5479,6 +5488,235 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
|
|||
'mixture.*aggregate function.*group aggregate pandas UDF'):
|
||||
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not _have_pandas or not _have_pyarrow,
|
||||
_pandas_requirement_message or _pyarrow_requirement_message)
|
||||
class WindowPandasUDFTests(ReusedSQLTestCase):
|
||||
@property
|
||||
def data(self):
|
||||
from pyspark.sql.functions import array, explode, col, lit
|
||||
return self.spark.range(10).toDF('id') \
|
||||
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
|
||||
.withColumn("v", explode(col('vs'))) \
|
||||
.drop('vs') \
|
||||
.withColumn('w', lit(1.0))
|
||||
|
||||
@property
|
||||
def python_plus_one(self):
|
||||
from pyspark.sql.functions import udf
|
||||
return udf(lambda v: v + 1, 'double')
|
||||
|
||||
@property
def pandas_scalar_time_two(self):
    """A scalar pandas UDF that doubles its input column.

    Only ``pandas_udf`` is needed here (the default eval type is SCALAR);
    the previously imported ``PandasUDFType`` was unused.
    """
    from pyspark.sql.functions import pandas_udf
    return pandas_udf(lambda v: v * 2, 'double')
|
||||
|
||||
@property
|
||||
def pandas_agg_mean_udf(self):
|
||||
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
|
||||
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
|
||||
def avg(v):
|
||||
return v.mean()
|
||||
return avg
|
||||
|
||||
@property
|
||||
def pandas_agg_max_udf(self):
|
||||
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
|
||||
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
|
||||
def max(v):
|
||||
return v.max()
|
||||
return max
|
||||
|
||||
@property
|
||||
def pandas_agg_min_udf(self):
|
||||
from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
|
||||
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
|
||||
def min(v):
|
||||
return v.min()
|
||||
return min
|
||||
|
||||
@property
|
||||
def unbounded_window(self):
|
||||
return Window.partitionBy('id') \
|
||||
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
|
||||
|
||||
@property
|
||||
def ordered_window(self):
|
||||
return Window.partitionBy('id').orderBy('v')
|
||||
|
||||
@property
|
||||
def unpartitioned_window(self):
|
||||
return Window.partitionBy()
|
||||
|
||||
def test_simple(self):
    """A grouped-agg pandas UDF over an unbounded window matches SQL mean().

    Only ``mean`` is actually used; the previous import list also pulled in
    ``pandas_udf``, ``PandasUDFType``, ``percent_rank`` and ``max`` unused
    (``max`` additionally shadowed the builtin).
    """
    from pyspark.sql.functions import mean

    df = self.data
    w = self.unbounded_window

    mean_udf = self.pandas_agg_mean_udf

    # Compare the UDF result against the built-in SQL aggregate, both when
    # appending a column and when selecting the windowed expression alone.
    result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
    expected1 = df.withColumn('mean_v', mean(df['v']).over(w))

    result2 = df.select(mean_udf(df['v']).over(w))
    expected2 = df.select(mean(df['v']).over(w))

    self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
    self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
||||
|
||||
def test_multiple_udfs(self):
|
||||
from pyspark.sql.functions import max, min, mean
|
||||
|
||||
df = self.data
|
||||
w = self.unbounded_window
|
||||
|
||||
result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
|
||||
.withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
|
||||
.withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
|
||||
|
||||
expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
|
||||
.withColumn('max_v', max(df['v']).over(w)) \
|
||||
.withColumn('min_w', min(df['w']).over(w))
|
||||
|
||||
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
||||
|
||||
def test_replace_existing(self):
|
||||
from pyspark.sql.functions import mean
|
||||
|
||||
df = self.data
|
||||
w = self.unbounded_window
|
||||
|
||||
result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
|
||||
expected1 = df.withColumn('v', mean(df['v']).over(w))
|
||||
|
||||
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
||||
|
||||
def test_mixed_sql(self):
|
||||
from pyspark.sql.functions import mean
|
||||
|
||||
df = self.data
|
||||
w = self.unbounded_window
|
||||
mean_udf = self.pandas_agg_mean_udf
|
||||
|
||||
result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
|
||||
expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)
|
||||
|
||||
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
||||
|
||||
def test_mixed_udf(self):
|
||||
from pyspark.sql.functions import mean
|
||||
|
||||
df = self.data
|
||||
w = self.unbounded_window
|
||||
|
||||
plus_one = self.python_plus_one
|
||||
time_two = self.pandas_scalar_time_two
|
||||
mean_udf = self.pandas_agg_mean_udf
|
||||
|
||||
result1 = df.withColumn(
|
||||
'v2',
|
||||
plus_one(mean_udf(plus_one(df['v'])).over(w)))
|
||||
expected1 = df.withColumn(
|
||||
'v2',
|
||||
plus_one(mean(plus_one(df['v'])).over(w)))
|
||||
|
||||
result2 = df.withColumn(
|
||||
'v2',
|
||||
time_two(mean_udf(time_two(df['v'])).over(w)))
|
||||
expected2 = df.withColumn(
|
||||
'v2',
|
||||
time_two(mean(time_two(df['v'])).over(w)))
|
||||
|
||||
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
||||
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
||||
|
||||
def test_without_partitionBy(self):
|
||||
from pyspark.sql.functions import mean
|
||||
|
||||
df = self.data
|
||||
w = self.unpartitioned_window
|
||||
mean_udf = self.pandas_agg_mean_udf
|
||||
|
||||
result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
|
||||
expected1 = df.withColumn('v2', mean(df['v']).over(w))
|
||||
|
||||
result2 = df.select(mean_udf(df['v']).over(w))
|
||||
expected2 = df.select(mean(df['v']).over(w))
|
||||
|
||||
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
||||
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
||||
|
||||
def test_mixed_sql_and_udf(self):
|
||||
from pyspark.sql.functions import max, min, rank, col
|
||||
|
||||
df = self.data
|
||||
w = self.unbounded_window
|
||||
ow = self.ordered_window
|
||||
max_udf = self.pandas_agg_max_udf
|
||||
min_udf = self.pandas_agg_min_udf
|
||||
|
||||
result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
|
||||
expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))
|
||||
|
||||
# Test mixing sql window function and window udf in the same expression
|
||||
result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
|
||||
expected2 = expected1
|
||||
|
||||
# Test chaining sql aggregate function and udf
|
||||
result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
|
||||
.withColumn('min_v', min(df['v']).over(w)) \
|
||||
.withColumn('v_diff', col('max_v') - col('min_v')) \
|
||||
.drop('max_v', 'min_v')
|
||||
expected3 = expected1
|
||||
|
||||
# Test mixing sql window function and udf
|
||||
result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
|
||||
.withColumn('rank', rank().over(ow))
|
||||
expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
|
||||
.withColumn('rank', rank().over(ow))
|
||||
|
||||
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
|
||||
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
|
||||
self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
|
||||
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
|
||||
|
||||
def test_array_type(self):
    """Window pandas UDF should support array<double> as an output type."""
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    df = self.data
    w = self.unbounded_window

    array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
    result1 = df.withColumn('v2', array_udf(df['v']).over(w))
    # assertEqual: assertEquals is a deprecated unittest alias and raises a
    # DeprecationWarning on modern Python.
    self.assertEqual(result1.first()['v2'], [1.0, 2.0])
|
||||
|
||||
def test_invalid_args(self):
|
||||
from pyspark.sql.functions import mean, pandas_udf, PandasUDFType
|
||||
|
||||
df = self.data
|
||||
w = self.unbounded_window
|
||||
ow = self.ordered_window
|
||||
mean_udf = self.pandas_agg_mean_udf
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
AnalysisException,
|
||||
'.*not supported within a window function'):
|
||||
foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
|
||||
df.withColumn('v2', foo_udf(df['v']).over(w))
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
AnalysisException,
|
||||
'.*Only unbounded window frame is supported.*'):
|
||||
df.withColumn('mean_v', mean_udf(df['v']).over(ow))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pyspark.sql.tests import *
|
||||
if xmlrunner:
|
||||
|
|
|
@ -128,6 +128,21 @@ def wrap_grouped_agg_pandas_udf(f, return_type):
|
|||
return lambda *a: (wrapped(*a), arrow_return_type)
|
||||
|
||||
|
||||
def wrap_window_agg_pandas_udf(f, return_type):
    """Wrap a grouped-aggregate pandas UDF for use as a window function.

    Like the grouped-aggregate wrapper, except that the single aggregated
    value is repeated to match the window length, so every row in the window
    receives the result instead of one scalar being emitted per group.
    """
    arrow_return_type = to_arrow_type(return_type)

    def evaluate(*series):
        import pandas as pd
        # Aggregate once, then broadcast the scalar across the whole window.
        scalar_result = f(*series)
        window_length = len(series[0])
        return pd.Series([scalar_result]).repeat(window_length)

    return lambda *a: (evaluate(*a), arrow_return_type)
|
||||
|
||||
|
||||
def read_single_udf(pickleSer, infile, eval_type):
|
||||
num_arg = read_int(infile)
|
||||
arg_offsets = [read_int(infile) for i in range(num_arg)]
|
||||
|
@ -151,6 +166,8 @@ def read_single_udf(pickleSer, infile, eval_type):
|
|||
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
|
||||
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
|
||||
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
|
||||
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
|
||||
return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
|
||||
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
|
||||
return arg_offsets, wrap_udf(func, return_type)
|
||||
else:
|
||||
|
@ -195,7 +212,8 @@ def read_udfs(pickleSer, infile, eval_type):
|
|||
|
||||
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
|
||||
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
|
||||
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
|
||||
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
|
||||
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
|
||||
timezone = utf8_deserializer.loads(infile)
|
||||
ser = ArrowStreamPandasSerializer(timezone)
|
||||
else:
|
||||
|
|
|
@ -1739,9 +1739,10 @@ class Analyzer(
|
|||
* 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
|
||||
* it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
|
||||
* all regular expressions.
|
||||
* 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
|
||||
* 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
|
||||
* it into the plan tree.
|
||||
* 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s
|
||||
* and [[WindowFunctionType]]s.
|
||||
* 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a
|
||||
* [[Window]] operator and inserts it into the plan tree.
|
||||
*/
|
||||
object ExtractWindowExpressions extends Rule[LogicalPlan] {
|
||||
private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
|
||||
|
@ -1901,7 +1902,7 @@ class Analyzer(
|
|||
s"Please file a bug report with this error message, stack trace, and the query.")
|
||||
} else {
|
||||
val spec = distinctWindowSpec.head
|
||||
(spec.partitionSpec, spec.orderSpec)
|
||||
(spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr))
|
||||
}
|
||||
}.toSeq
|
||||
|
||||
|
@ -1909,7 +1910,7 @@ class Analyzer(
|
|||
// setting this to the child of the next Window operator.
|
||||
val windowOps =
|
||||
groupedWindowExpressions.foldLeft(child) {
|
||||
case (last, ((partitionSpec, orderSpec), windowExpressions)) =>
|
||||
case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
|
||||
Window(windowExpressions, partitionSpec, orderSpec, last)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import org.apache.spark.api.python.PythonEvalType
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
|
||||
|
@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper {
|
|||
failAnalysis("An offset window function can only be evaluated in an ordered " +
|
||||
s"row-based window frame with a single offset: $w")
|
||||
|
||||
case _ @ WindowExpression(_: PythonUDF,
|
||||
WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
|
||||
if !frame.isUnbounded =>
|
||||
failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")
|
||||
|
||||
case w @ WindowExpression(e, s) =>
|
||||
// Only allow window functions with an aggregate expression or an offset window
|
||||
// function.
|
||||
// function or a Pandas window UDF.
|
||||
e match {
|
||||
case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction =>
|
||||
w
|
||||
case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) =>
|
||||
w
|
||||
case _ =>
|
||||
failAnalysis(s"Expression '$e' not supported within a window function.")
|
||||
}
|
||||
|
@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper {
|
|||
|
||||
case Aggregate(groupingExprs, aggregateExprs, child) =>
|
||||
def isAggregateExpression(expr: Expression) = {
|
||||
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr)
|
||||
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
|
||||
}
|
||||
|
||||
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
|
||||
|
|
|
@ -34,10 +34,14 @@ object PythonUDF {
|
|||
e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType)
|
||||
}
|
||||
|
||||
def isGroupAggPandasUDF(e: Expression): Boolean = {
|
||||
def isGroupedAggPandasUDF(e: Expression): Boolean = {
|
||||
e.isInstanceOf[PythonUDF] &&
|
||||
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
|
||||
}
|
||||
|
||||
// This is currently same as GroupedAggPandasUDF, but we might support new types in the future,
|
||||
// e.g, N -> N transform.
|
||||
def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -21,7 +21,7 @@ import java.util.Locale
|
|||
|
||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -297,6 +297,37 @@ trait WindowFunction extends Expression {
|
|||
def frame: WindowFrame = UnspecifiedFrame
|
||||
}
|
||||
|
||||
/**
 * Case objects describing whether a window function is a built-in SQL window
 * function or a Python user-defined window function.
 */
sealed trait WindowFunctionType

object WindowFunctionType {
  case object SQL extends WindowFunctionType
  case object Python extends WindowFunctionType

  def functionType(windowExpression: NamedExpression): WindowFunctionType = {
    // A window expression normally carries either a SQL window function, a SQL
    // aggregate function, or a Python window UDF. The optimizer, however, may
    // replace the window function when its value can be predetermined, e.g. for
    //
    //   select count(NULL) over () from values 1.0, 2.0, 3.0 T(a)
    //
    // the window function is rewritten to literal(0). Such a literal is a SQL
    // expression, so we default to SQL when no regular window function is found.
    windowExpression.collectFirst {
      case _: WindowFunction | _: AggregateFunction => SQL
      case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python
    }.getOrElse(SQL)
  }
}
|
||||
|
||||
|
||||
/**
|
||||
* An offset window function is a window function that returns the value of the input column offset
|
||||
* by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with
|
||||
|
|
|
@ -621,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] {
|
|||
/**
|
||||
* Collapse Adjacent Window Expression.
|
||||
* - If the partition specs and order specs are the same and the window expression are
|
||||
* independent, collapse into the parent.
|
||||
* independent and are of the same window function type, collapse into the parent.
|
||||
*/
|
||||
object CollapseWindow extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
|
||||
case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
|
||||
if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty =>
|
||||
if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
|
||||
// This assumes Window contains the same type of window expressions. This is ensured
|
||||
// by ExtractWindowFunctions.
|
||||
WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) =>
|
||||
w1.copy(windowExpressions = we2 ++ we1, child = grandChild)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.planning
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
|
@ -215,7 +216,7 @@ object PhysicalAggregation {
|
|||
case agg: AggregateExpression
|
||||
if !equivalentAggregateExpressions.addExpr(agg) => agg
|
||||
case udf: PythonUDF
|
||||
if PythonUDF.isGroupAggPandasUDF(udf) &&
|
||||
if PythonUDF.isGroupedAggPandasUDF(udf) &&
|
||||
!equivalentAggregateExpressions.addExpr(udf) => udf
|
||||
}
|
||||
}
|
||||
|
@ -245,7 +246,7 @@ object PhysicalAggregation {
|
|||
equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
|
||||
.getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
|
||||
// Similar to AggregateExpression
|
||||
case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) =>
|
||||
case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
|
||||
equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
|
||||
.getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
|
||||
case expression =>
|
||||
|
@ -268,3 +269,40 @@ object PhysicalAggregation {
|
|||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * An extractor used when planning physical execution of a window. This extractor outputs
 * the window function type of the logical window.
 *
 * The input logical window must contain same type of window functions, which is ensured by
 * the rule ExtractWindowExpressions in the analyzer.
 */
object PhysicalWindow {
  // windowFunctionType, windowExpression, partitionSpec, orderSpec, child
  private type ReturnType =
    (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)

  def unapply(a: Any): Option[ReturnType] = a match {
    case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>
      // An empty window expression list here indicates a bug upstream.
      if (windowExpressions.isEmpty) {
        throw new AnalysisException(s"Window expression is empty in $expr")
      }

      // All expressions in a single Window operator must share one function
      // type (guaranteed by ExtractWindowExpressions); anything else is a bug.
      val functionTypes = windowExpressions.map(WindowFunctionType.functionType).distinct
      if (functionTypes.length > 1) {
        throw new AnalysisException(
          s"Found different window function type in $windowExpressions")
      }

      Some((functionTypes.head, windowExpressions, partitionSpec, orderSpec, child))

    case _ => None
  }
}
|
||||
|
|
|
@ -41,6 +41,7 @@ class SparkPlanner(
|
|||
DataSourceStrategy(conf) ::
|
||||
SpecialLimits ::
|
||||
Aggregation ::
|
||||
Window ::
|
||||
JoinSelection ::
|
||||
InMemoryScans ::
|
||||
BasicOperators :: Nil)
|
||||
|
|
|
@ -327,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case PhysicalAggregation(
|
||||
namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
|
||||
|
||||
if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) {
|
||||
if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) {
|
||||
throw new AnalysisException(
|
||||
"Streaming aggregation doesn't support group aggregate pandas UDF")
|
||||
}
|
||||
|
@ -428,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
object Window extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalWindow(functionType, windowExprs, partitionSpec, orderSpec, child) =>
      // Pick the physical operator matching the window function type: built-in
      // SQL window functions run in WindowExec, Pandas UDFs in WindowInPandasExec.
      val windowPlan = functionType match {
        case WindowFunctionType.SQL =>
          execution.window.WindowExec(
            windowExprs, partitionSpec, orderSpec, planLater(child))
        case WindowFunctionType.Python =>
          execution.python.WindowInPandasExec(
            windowExprs, partitionSpec, orderSpec, planLater(child))
      }
      windowPlan :: Nil

    case _ => Nil
  }
}
|
||||
|
||||
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
|
||||
|
||||
object InMemoryScans extends Strategy {
|
||||
|
@ -548,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
|
||||
case e @ logical.Expand(_, _, child) =>
|
||||
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
|
||||
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
|
||||
execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
|
||||
case logical.Sample(lb, ub, withReplacement, seed, child) =>
|
||||
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
|
||||
case logical.LocalRelation(output, data, _) =>
|
||||
|
|
|
@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
|
|||
*/
|
||||
private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
|
||||
e.isInstanceOf[AggregateExpression] ||
|
||||
PythonUDF.isGroupAggPandasUDF(e) ||
|
||||
PythonUDF.isGroupedAggPandasUDF(e) ||
|
||||
agg.groupingExpressions.exists(_.semanticEquals(e))
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,173 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.python
|
||||
|
||||
import java.io.File
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.{SparkEnv, TaskContext}
|
||||
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
|
||||
import org.apache.spark.sql.types.{DataType, StructField, StructType}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
case class WindowInPandasExec(
|
||||
windowExpression: Seq[NamedExpression],
|
||||
partitionSpec: Seq[Expression],
|
||||
orderSpec: Seq[SortOrder],
|
||||
child: SparkPlan) extends UnaryExecNode {
|
||||
|
||||
override def output: Seq[Attribute] =
|
||||
child.output ++ windowExpression.map(_.toAttribute)
|
||||
|
||||
override def requiredChildDistribution: Seq[Distribution] = {
  if (partitionSpec.nonEmpty) {
    ClusteredDistribution(partitionSpec) :: Nil
  } else {
    // With no PARTITION BY clause every row must be co-located, which forces
    // all data into a single partition and can seriously degrade performance.
    logWarning("No Partition Defined for Window operation! Moving all data to a single "
      + "partition, this can cause serious performance degradation.")
    AllTuples :: Nil
  }
}
|
||||
|
||||
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
|
||||
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
|
||||
|
||||
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
|
||||
|
||||
override def outputPartitioning: Partitioning = child.outputPartitioning
|
||||
|
||||
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
  udf.children match {
    case Seq(inner: PythonUDF) =>
      // Chained UDF: recurse into the inner UDF, then append this function on top.
      val (innerChain, baseInputs) = collectFunctions(inner)
      (ChainedPythonFunctions(innerChain.funcs ++ Seq(udf.func)), baseInputs)
    case children =>
      // Base case: the inputs must be directly evaluable, i.e. contain no
      // further nested Python UDFs.
      assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
      (ChainedPythonFunctions(Seq(udf.func)), udf.children)
  }
}
|
||||
|
||||
/**
|
||||
* Create the resulting projection.
|
||||
*
|
||||
* This method uses Code Generation. It can only be used on the executor side.
|
||||
*
|
||||
* @param expressions unbound ordered function expressions.
|
||||
* @return the final resulting projection.
|
||||
*/
|
||||
private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
|
||||
val references = expressions.zipWithIndex.map { case (e, i) =>
|
||||
// Results of window expressions will be on the right side of child's output
|
||||
BoundReference(child.output.size + i, e.dataType, e.nullable)
|
||||
}
|
||||
val unboundToRefMap = expressions.zip(references).toMap
|
||||
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
|
||||
UnsafeProjection.create(
|
||||
child.output ++ patchedWindowExpression,
|
||||
child.output)
|
||||
}
|
||||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
val inputRDD = child.execute()
|
||||
|
||||
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
|
||||
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
|
||||
val sessionLocalTimeZone = conf.sessionLocalTimeZone
|
||||
val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
|
||||
|
||||
// Extract window expressions and window functions
|
||||
val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
|
||||
|
||||
val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF])
|
||||
|
||||
val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
|
||||
|
||||
// Filter child output attributes down to only those that are UDF inputs.
|
||||
// Also eliminate duplicate UDF inputs.
|
||||
val allInputs = new ArrayBuffer[Expression]
|
||||
val dataTypes = new ArrayBuffer[DataType]
|
||||
val argOffsets = inputs.map { input =>
|
||||
input.map { e =>
|
||||
if (allInputs.exists(_.semanticEquals(e))) {
|
||||
allInputs.indexWhere(_.semanticEquals(e))
|
||||
} else {
|
||||
allInputs += e
|
||||
dataTypes += e.dataType
|
||||
allInputs.length - 1
|
||||
}
|
||||
}.toArray
|
||||
}.toArray
|
||||
|
||||
// Schema of input rows to the python runner
|
||||
val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
|
||||
StructField(s"_$i", dt)
|
||||
})
|
||||
|
||||
inputRDD.mapPartitionsInternal { iter =>
|
||||
val context = TaskContext.get()
|
||||
|
||||
val grouped = if (partitionSpec.isEmpty) {
|
||||
// Use an empty unsafe row as a place holder for the grouping key
|
||||
Iterator((new UnsafeRow(), iter))
|
||||
} else {
|
||||
GroupedIterator(iter, partitionSpec, child.output)
|
||||
}
|
||||
|
||||
// The queue used to buffer input rows so we can drain it to
|
||||
// combine input with output from Python.
|
||||
val queue = HybridRowQueue(context.taskMemoryManager(),
|
||||
new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
|
||||
context.addTaskCompletionListener { _ =>
|
||||
queue.close()
|
||||
}
|
||||
|
||||
val inputProj = UnsafeProjection.create(allInputs, child.output)
|
||||
val pythonInput = grouped.map { case (_, rows) =>
|
||||
rows.map { row =>
|
||||
queue.add(row.asInstanceOf[UnsafeRow])
|
||||
inputProj(row)
|
||||
}
|
||||
}
|
||||
|
||||
val windowFunctionResult = new ArrowPythonRunner(
|
||||
pyFuncs, bufferSize, reuseWorker,
|
||||
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
|
||||
argOffsets, windowInputSchema,
|
||||
sessionLocalTimeZone, pandasRespectSessionTimeZone)
|
||||
.compute(pythonInput, context.partitionId(), context)
|
||||
|
||||
val joined = new JoinedRow
|
||||
val resultProj = createResultProjection(expressions)
|
||||
|
||||
windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
|
||||
val leftRow = queue.remove()
|
||||
val joinedRow = joined(leftRow, windowOutput)
|
||||
resultProj(joinedRow)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue