[SPARK-22239][SQL][PYTHON] Enable grouped aggregate pandas UDFs as window functions with unbounded window frames

## What changes were proposed in this pull request?
This PR enables using grouped aggregate pandas UDFs as window functions. The semantics are the same as using a SQL aggregate function as a window function.

```
       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
       >>> from pyspark.sql import Window
       >>> df = spark.createDataFrame(
       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
       ...     ("id", "v"))
       >>> pandas_udf("double", PandasUDFType.GROUPED_AGG)
       ... def mean_udf(v):
       ...     return v.mean()
       >>> w = Window.partitionBy('id')
       >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
       +---+----+------+
       | id|   v|mean_v|
       +---+----+------+
       |  1| 1.0|   1.5|
       |  1| 2.0|   1.5|
       |  2| 3.0|   6.0|
       |  2| 5.0|   6.0|
       |  2|10.0|   6.0|
       +---+----+------+
```

The scope of this PR is somewhat limited:
(1) Only unbounded window frames are supported, so the UDF acts essentially as a group-by aggregation.
(2) Only aggregation functions are supported, not "transform"-like window functions (n -> n mappings).

Both of these are left as future work. In particular, (1) needs careful thinking about how to pass rolling window data to Python efficiently. (2) is a bit easier but requires more changes, so it is better left to a separate PR.
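As a quick illustration of restriction (1), here is a sketch (not from this patch; it assumes an active SparkSession named `spark` with pandas and pyarrow installed): adding `orderBy` to the window spec implies a bounded (growing) frame, which the new analysis check rejects, matching `WindowPandasUDFTests.test_invalid_args` below.

```
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.utils import AnalysisException

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    return v.mean()

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

# orderBy gives the window a bounded (growing) frame, which is not supported yet.
ow = Window.partitionBy("id").orderBy("v")
try:
    df.withColumn("mean_v", mean_udf(df["v"]).over(ow))
except AnalysisException as e:
    print(e)  # ... Only unbounded window frame is supported with Pandas UDFs.
```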

## How was this patch tested?

New tests in `WindowPandasUDFTests`, added to `pyspark.sql.tests`.

Author: Li Jin <ice.xelloss@gmail.com>

Closes #21082 from icexelloss/SPARK-22239-window-udf.
Authored by Li Jin on 2018-06-13 09:10:52 +08:00, committed by hyukjinkwon
Commit 9786ce66c5 (parent f53818d35b)
15 changed files with 580 additions and 22 deletions

@ -40,6 +40,7 @@ private[spark] object PythonEvalType {
val SQL_SCALAR_PANDAS_UDF = 200
val SQL_GROUPED_MAP_PANDAS_UDF = 201
val SQL_GROUPED_AGG_PANDAS_UDF = 202
val SQL_WINDOW_AGG_PANDAS_UDF = 203
def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@ -47,6 +48,7 @@ private[spark] object PythonEvalType {
case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
}
}

@ -74,6 +74,7 @@ class PythonEvalType(object):
SQL_SCALAR_PANDAS_UDF = 200
SQL_GROUPED_MAP_PANDAS_UDF = 201
SQL_GROUPED_AGG_PANDAS_UDF = 202
SQL_WINDOW_AGG_PANDAS_UDF = 203
def portable_hash(x):

@ -2616,10 +2616,12 @@ def pandas_udf(f=None, returnType=None, functionType=None):
The returned scalar can be either a python primitive type, e.g., `int` or `float`
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
:class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
output types.
:class:`MapType` and :class:`StructType` are currently not supported as output types.
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
:class:`pyspark.sql.Window`
This example shows using grouped aggregated UDFs with groupby:
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
@ -2636,7 +2638,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 6.0|
+---+-----------+
.. seealso:: :meth:`pyspark.sql.GroupedData.agg`
This example shows using grouped aggregated UDFs as window functions. Note that only
unbounded window frame is supported at the moment:
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql import Window
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
>>> w = Window.partitionBy('id') \\
... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
+---+----+------+
| id| v|mean_v|
+---+----+------+
| 1| 1.0| 1.5|
| 1| 2.0| 1.5|
| 2| 3.0| 6.0|
| 2| 5.0| 6.0|
| 2|10.0| 6.0|
+---+----+------+
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
.. note:: The user-defined functions are considered deterministic by default. Due to
optimization, duplicate invocations may be eliminated or the function may even be invoked

@ -5454,6 +5454,15 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
expected1 = df.groupby(df.id).agg(sum(df.v))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_array_type(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
def test_invalid_args(self):
from pyspark.sql.functions import mean
@ -5479,6 +5488,235 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class WindowPandasUDFTests(ReusedSQLTestCase):
@property
def data(self):
from pyspark.sql.functions import array, explode, col, lit
return self.spark.range(10).toDF('id') \
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))) \
.drop('vs') \
.withColumn('w', lit(1.0))
@property
def python_plus_one(self):
from pyspark.sql.functions import udf
return udf(lambda v: v + 1, 'double')
@property
def pandas_scalar_time_two(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
return pandas_udf(lambda v: v * 2, 'double')
@property
def pandas_agg_mean_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def avg(v):
return v.mean()
return avg
@property
def pandas_agg_max_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def max(v):
return v.max()
return max
@property
def pandas_agg_min_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def min(v):
return v.min()
return min
@property
def unbounded_window(self):
return Window.partitionBy('id') \
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
@property
def ordered_window(self):
return Window.partitionBy('id').orderBy('v')
@property
def unpartitioned_window(self):
return Window.partitionBy()
def test_simple(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max
df = self.data
w = self.unbounded_window
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
expected1 = df.withColumn('mean_v', mean(df['v']).over(w))
result2 = df.select(mean_udf(df['v']).over(w))
expected2 = df.select(mean(df['v']).over(w))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
def test_multiple_udfs(self):
from pyspark.sql.functions import max, min, mean
df = self.data
w = self.unbounded_window
result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
.withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
.withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
.withColumn('max_v', max(df['v']).over(w)) \
.withColumn('min_w', min(df['w']).over(w))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_replace_existing(self):
from pyspark.sql.functions import mean
df = self.data
w = self.unbounded_window
result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
expected1 = df.withColumn('v', mean(df['v']).over(w))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_mixed_sql(self):
from pyspark.sql.functions import mean
df = self.data
w = self.unbounded_window
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_mixed_udf(self):
from pyspark.sql.functions import mean
df = self.data
w = self.unbounded_window
plus_one = self.python_plus_one
time_two = self.pandas_scalar_time_two
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn(
'v2',
plus_one(mean_udf(plus_one(df['v'])).over(w)))
expected1 = df.withColumn(
'v2',
plus_one(mean(plus_one(df['v'])).over(w)))
result2 = df.withColumn(
'v2',
time_two(mean_udf(time_two(df['v'])).over(w)))
expected2 = df.withColumn(
'v2',
time_two(mean(time_two(df['v'])).over(w)))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
def test_without_partitionBy(self):
from pyspark.sql.functions import mean
df = self.data
w = self.unpartitioned_window
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
expected1 = df.withColumn('v2', mean(df['v']).over(w))
result2 = df.select(mean_udf(df['v']).over(w))
expected2 = df.select(mean(df['v']).over(w))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
def test_mixed_sql_and_udf(self):
from pyspark.sql.functions import max, min, rank, col
df = self.data
w = self.unbounded_window
ow = self.ordered_window
max_udf = self.pandas_agg_max_udf
min_udf = self.pandas_agg_min_udf
result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))
# Test mixing sql window function and window udf in the same expression
result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
expected2 = expected1
# Test chaining sql aggregate function and udf
result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
.withColumn('min_v', min(df['v']).over(w)) \
.withColumn('v_diff', col('max_v') - col('min_v')) \
.drop('max_v', 'min_v')
expected3 = expected1
# Test mixing sql window function and udf
result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
.withColumn('rank', rank().over(ow))
expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
.withColumn('rank', rank().over(ow))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
def test_array_type(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
w = self.unbounded_window
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
result1 = df.withColumn('v2', array_udf(df['v']).over(w))
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
def test_invalid_args(self):
from pyspark.sql.functions import mean, pandas_udf, PandasUDFType
df = self.data
w = self.unbounded_window
ow = self.ordered_window
mean_udf = self.pandas_agg_mean_udf
with QuietTest(self.sc):
with self.assertRaisesRegexp(
AnalysisException,
'.*not supported within a window function'):
foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
df.withColumn('v2', foo_udf(df['v']).over(w))
with QuietTest(self.sc):
with self.assertRaisesRegexp(
AnalysisException,
'.*Only unbounded window frame is supported.*'):
df.withColumn('mean_v', mean_udf(df['v']).over(ow))
if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:

@ -128,6 +128,21 @@ def wrap_grouped_agg_pandas_udf(f, return_type):
return lambda *a: (wrapped(*a), arrow_return_type)
def wrap_window_agg_pandas_udf(f, return_type):
# This is similar to grouped_agg_pandas_udf, the only difference
# is that window_agg_pandas_udf needs to repeat the return value
# to match window length, where grouped_agg_pandas_udf just returns
# the scalar value.
arrow_return_type = to_arrow_type(return_type)
def wrapped(*series):
import pandas as pd
result = f(*series)
return pd.Series([result]).repeat(len(series[0]))
return lambda *a: (wrapped(*a), arrow_return_type)
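To make the repeat step concrete, here is a small standalone pandas sketch (not part of the patch) of what `wrap_window_agg_pandas_udf` does for one unbounded window: the UDF produces a single scalar, which is then repeated to the window length so every input row receives the aggregated value.

```
import pandas as pd

v = pd.Series([3.0, 5.0, 10.0])   # values of one unbounded window partition
result = v.mean()                  # scalar produced by the grouped agg UDF
repeated = pd.Series([result]).repeat(len(v))
print(repeated.tolist())           # [6.0, 6.0, 6.0] -- one value per input row
```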
def read_single_udf(pickleSer, infile, eval_type):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
@ -151,6 +166,8 @@ def read_single_udf(pickleSer, infile, eval_type):
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return arg_offsets, wrap_udf(func, return_type)
else:
@ -195,7 +212,8 @@ def read_udfs(pickleSer, infile, eval_type):
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
timezone = utf8_deserializer.loads(infile)
ser = ArrowStreamPandasSerializer(timezone)
else:

@ -1739,9 +1739,10 @@ class Analyzer(
* 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
* it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
* all regular expressions.
* 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
* 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
* it into the plan tree.
* 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s
* and [[WindowFunctionType]]s.
* 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a
* [[Window]] operator and inserts it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
@ -1901,7 +1902,7 @@ class Analyzer(
s"Please file a bug report with this error message, stack trace, and the query.")
} else {
val spec = distinctWindowSpec.head
(spec.partitionSpec, spec.orderSpec)
(spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr))
}
}.toSeq
@ -1909,7 +1910,7 @@ class Analyzer(
// setting this to the child of the next Window operator.
val windowOps =
groupedWindowExpressions.foldLeft(child) {
case (last, ((partitionSpec, orderSpec), windowExpressions)) =>
case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
Window(windowExpressions, partitionSpec, orderSpec, last)
}

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis("An offset window function can only be evaluated in an ordered " +
s"row-based window frame with a single offset: $w")
case _ @ WindowExpression(_: PythonUDF,
WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
if !frame.isUnbounded =>
failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")
case w @ WindowExpression(e, s) =>
// Only allow window functions with an aggregate expression or an offset window
// function.
// function or a Pandas window UDF.
e match {
case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction =>
w
case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) =>
w
case _ =>
failAnalysis(s"Expression '$e' not supported within a window function.")
}
@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def isAggregateExpression(expr: Expression) = {
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr)
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
}
def checkValidAggregateExpression(expr: Expression): Unit = expr match {

@ -34,10 +34,14 @@ object PythonUDF {
e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType)
}
def isGroupAggPandasUDF(e: Expression): Boolean = {
def isGroupedAggPandasUDF(e: Expression): Boolean = {
e.isInstanceOf[PythonUDF] &&
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
}
// This is currently same as GroupedAggPandasUDF, but we might support new types in the future,
// e.g, N -> N transform.
def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e)
}
/**

@ -21,7 +21,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp}
import org.apache.spark.sql.types._
/**
@ -297,6 +297,37 @@ trait WindowFunction extends Expression {
def frame: WindowFrame = UnspecifiedFrame
}
/**
* Case objects that describe whether a window function is a SQL window function or a Python
* user-defined window function.
*/
sealed trait WindowFunctionType
object WindowFunctionType {
case object SQL extends WindowFunctionType
case object Python extends WindowFunctionType
def functionType(windowExpression: NamedExpression): WindowFunctionType = {
val t = windowExpression.collectFirst {
case _: WindowFunction | _: AggregateFunction => SQL
case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python
}
// Normally a window expression would either have a SQL window function, a SQL
// aggregate function or a python window UDF. However, sometimes the optimizer will replace
// the window function if the value of the window function can be predetermined.
// For example, for query:
//
// select count(NULL) over () from values 1.0, 2.0, 3.0 T(a)
//
// The window function will be replaced by expression literal(0)
// To handle this case, if a window expression doesn't have a regular window function, we
// consider its type to be SQL as literal(0) is also a SQL expression.
t.getOrElse(SQL)
}
}
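As a quick check of the corner case described in the comment above, this hedged sketch (assuming an active SparkSession named `spark`) runs the quoted query; the optimizer folds `count(NULL)` to a literal 0, so the window expression has no window or aggregate function left and is classified as a SQL window expression.

```
# Assumes an active SparkSession named `spark`.
spark.sql("select count(NULL) over () from values 1.0, 2.0, 3.0 T(a)").show()
# Each of the three rows should show 0.
```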
/**
* An offset window function is a window function that returns the value of the input column offset
* by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with

@ -621,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] {
/**
* Collapse Adjacent Window Expression.
* - If the partition specs and order specs are the same and the window expression are
* independent, collapse into the parent.
* independent and are of the same window function type, collapse into the parent.
*/
object CollapseWindow extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty =>
if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
// This assumes Window contains the same type of window expressions. This is ensured
// by ExtractWindowFunctions.
WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) =>
w1.copy(windowExpressions = we2 ++ we1, child = grandChild)
}
}
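From the user's side, this is what the extra condition guards against. In the hedged sketch below (reusing the hypothetical `df` and `mean_udf` from the earlier sketch), a SQL `max` window function and a pandas window UDF share the same window spec but have different `WindowFunctionType`s, so `ExtractWindowExpressions` places them in separate `Window` operators and `CollapseWindow` leaves those operators uncollapsed.

```
from pyspark.sql import Window
from pyspark.sql.functions import max as sql_max

w = Window.partitionBy("id").rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing)

# One SQL window function and one pandas window UDF over the same window spec:
# they end up in two Window operators of different window function types.
mixed = df.withColumn("max_v", sql_max(df["v"]).over(w)) \
          .withColumn("mean_v", mean_udf(df["v"]).over(w))
```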

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.planning
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@ -215,7 +216,7 @@ object PhysicalAggregation {
case agg: AggregateExpression
if !equivalentAggregateExpressions.addExpr(agg) => agg
case udf: PythonUDF
if PythonUDF.isGroupAggPandasUDF(udf) &&
if PythonUDF.isGroupedAggPandasUDF(udf) &&
!equivalentAggregateExpressions.addExpr(udf) => udf
}
}
@ -245,7 +246,7 @@ object PhysicalAggregation {
equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
.getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
// Similar to AggregateExpression
case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) =>
case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
.getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
case expression =>
@ -268,3 +269,40 @@ object PhysicalAggregation {
case _ => None
}
}
/**
* An extractor used when planning physical execution of a window. This extractor outputs
* the window function type of the logical window.
*
* The input logical window must contain same type of window functions, which is ensured by
* the rule ExtractWindowExpressions in the analyzer.
*/
object PhysicalWindow {
// windowFunctionType, windowExpression, partitionSpec, orderSpec, child
private type ReturnType =
(WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)
def unapply(a: Any): Option[ReturnType] = a match {
case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>
// The window expression should not be empty here, otherwise it's a bug.
if (windowExpressions.isEmpty) {
throw new AnalysisException(s"Window expression is empty in $expr")
}
val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType)
.reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) =>
if (t1 != t2) {
// We shouldn't have different window function type here, otherwise it's a bug.
throw new AnalysisException(
s"Found different window function type in $windowExpressions")
} else {
t1
}
}
Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child))
case _ => None
}
}

@ -41,6 +41,7 @@ class SparkPlanner(
DataSourceStrategy(conf) ::
SpecialLimits ::
Aggregation ::
Window ::
JoinSelection ::
InMemoryScans ::
BasicOperators :: Nil)

@ -327,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case PhysicalAggregation(
namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) {
if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) {
throw new AnalysisException(
"Streaming aggregation doesn't support group aggregate pandas UDF")
}
@ -428,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
object Window extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalWindow(
WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) =>
execution.window.WindowExec(
windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case PhysicalWindow(
WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) =>
execution.python.WindowInPandasExec(
windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case _ => Nil
}
}
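For a rough end-to-end check (a sketch, not from the patch, reusing the hypothetical `df`, `mean_udf`, and unbounded window `w` from the earlier sketches), the new `Window` strategy routes the pandas window UDF to `WindowInPandasExec`, while SQL window functions keep going through `WindowExec`:

```
# Hedged sketch: the physical plan for the pandas UDF window should contain a
# WindowInPandas node, while a plain SQL window function such as
# mean(df["v"]).over(w) should still be planned as a Window node.
df.withColumn("mean_v", mean_udf(df["v"]).over(w)).explain()
```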
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
object InMemoryScans extends Strategy {
@ -548,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data, _) =>

@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupAggPandasUDF(e) ||
PythonUDF.isGroupedAggPandasUDF(e) ||
agg.groupingExpressions.exists(_.semanticEquals(e))
}

@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.python
import java.io.File
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
case class WindowInPandasExec(
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan) extends UnaryExecNode {
override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)
override def requiredChildDistribution: Seq[Distribution] = {
if (partitionSpec.isEmpty) {
// Only show warning when the number of bytes is larger than 100 MB?
logWarning("No Partition Defined for Window operation! Moving all data to a single "
+ "partition, this can cause serious performance degradation.")
AllTuples :: Nil
} else {
ClusteredDistribution(partitionSpec) :: Nil
}
}
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def outputPartitioning: Partitioning = child.outputPartitioning
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}
/**
* Create the resulting projection.
*
* This method uses Code Generation. It can only be used on the executor side.
*
* @param expressions unbound ordered function expressions.
* @return the final resulting projection.
*/
private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
val references = expressions.zipWithIndex.map { case (e, i) =>
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
}
val unboundToRefMap = expressions.zip(references).toMap
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
UnsafeProjection.create(
child.output ++ patchedWindowExpression,
child.output)
}
protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute()
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
// Extract window expressions and window functions
val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF])
val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
// Filter child output attributes down to only those that are UDF inputs.
// Also eliminate duplicate UDF inputs.
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}.toArray
}.toArray
// Schema of input rows to the python runner
val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
StructField(s"_$i", dt)
})
inputRDD.mapPartitionsInternal { iter =>
val context = TaskContext.get()
val grouped = if (partitionSpec.isEmpty) {
// Use an empty unsafe row as a place holder for the grouping key
Iterator((new UnsafeRow(), iter))
} else {
GroupedIterator(iter, partitionSpec, child.output)
}
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(context.taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
context.addTaskCompletionListener { _ =>
queue.close()
}
val inputProj = UnsafeProjection.create(allInputs, child.output)
val pythonInput = grouped.map { case (_, rows) =>
rows.map { row =>
queue.add(row.asInstanceOf[UnsafeRow])
inputProj(row)
}
}
val windowFunctionResult = new ArrowPythonRunner(
pyFuncs, bufferSize, reuseWorker,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
argOffsets, windowInputSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(pythonInput, context.partitionId(), context)
val joined = new JoinedRow
val resultProj = createResultProjection(expressions)
windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, windowOutput)
resultProj(joinedRow)
}
}
}
}