[SPARK-24561][SQL][PYTHON] User-defined window aggregation functions with Pandas UDF (bounded window)

## What changes were proposed in this pull request?

This PR implements a new feature: window aggregation Pandas UDFs for bounded windows.

#### Doc:
https://docs.google.com/document/d/14EjeY5z4-NC27-SmIP9CsMPCANeTcvxN44a7SIJtZPc/edit#heading=h.c87w44wcj3wj

#### Example:
```
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.window import Window

df = spark.range(0, 10, 2).toDF('v')
w1 = Window.partitionBy().orderBy('v').rangeBetween(-2, 4)
w2 = Window.partitionBy().orderBy('v').rowsBetween(-2, 2)

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def avg(v):
    return v.mean()

df.withColumn('v_mean', avg(df['v']).over(w1)).show()
# +---+------+
# |  v|v_mean|
# +---+------+
# |  0|   2.0|
# |  2|   3.0|
# |  4|   5.0|
# |  6|   6.0|
# |  8|   7.0|
# +---+------+

df.withColumn('v_mean', avg(df['v']).over(w2)).show()
# +---+------+
# |  v|v_mean|
# +---+------+
# |  0|   2.0|
# |  2|   3.0|
# |  4|   4.0|
# |  6|   5.0|
# |  8|   6.0|
# +---+------+

```

#### High level changes:

This PR modifies the existing `WindowInPandasExec` physical node to handle bounded (growing, shrinking and sliding) windows.

* `WindowInPandasExec` now shares the same base class as `WindowExec` and shares utility functions with it. See `WindowExecBase`.
* `WindowFunctionFrame` now has two new functions, `currentLowerBound` and `currentUpperBound`, which return the lower and upper window bound for the current output row. It is also modified to allow a null `AggregateProcessor`. A null aggregate processor is used for `WindowInPandasExec`, where we don't have an aggregator and only use the lower and upper bound functions from `WindowFunctionFrame`.
* The biggest change is in `WindowInPandasExec`, which is modified to compute `currentLowerBound` and `currentUpperBound` and send those values together with the input data to the Python worker for rolling window aggregation (see the sketch after this list). See `WindowInPandasExec` for more details.
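
To illustrate the evaluation model, here is a rough, self-contained sketch of how a bounded-window UDF is applied on the Python side, modeled on `wrap_bounded_window_agg_pandas_udf` in `worker.py` (the function name `evaluate_bounded_window` and the hard-coded bound arrays are illustrative only):

```
import pandas as pd

def evaluate_bounded_window(f, begin_index, end_index, *series):
    # begin_index/end_index hold, for each output row, the lower bound
    # (inclusive) and upper bound (exclusive) of its window within the
    # buffered partition, as computed on the JVM side by WindowFunctionFrame.
    begin_array = begin_index.values
    end_array = end_index.values
    result = []
    for i in range(len(begin_array)):
        # Slice each input series down to the current window and apply the
        # UDF once per output row.
        window_slices = [s.iloc[begin_array[i]:end_array[i]] for s in series]
        result.append(f(*window_slices))
    return pd.Series(result)

# A rowsBetween(-1, 1) window over a 4-row partition:
v = pd.Series([1.0, 2.0, 3.0, 4.0])
begin = pd.Series([0, 0, 1, 2])
end = pd.Series([2, 3, 4, 4])
print(evaluate_bounded_window(lambda s: s.mean(), begin, end, v))
# 0    1.5
# 1    2.0
# 2    3.0
# 3    3.5
# dtype: float64
```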

#### Discussion
In benchmarking, I found the numpy variant of the rolling window UDF to be much faster than the pandas version:

* Spark SQL window function: 20s
* Pandas variant: ~80s
* Numpy variant: 10s
* Numpy variant with numba: 4s

Allowing a numpy variant of the vectorized UDFs is something I want to discuss because of the performance improvement, but it doesn't have to be in this PR.
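
For context, the numpy variant under discussion would look roughly like the sketch below; it differs from the pandas version above only in that the UDF receives raw `ndarray` slices instead of `pd.Series` slices, which avoids per-window pandas slice construction. This is not part of this PR, and the function name is purely illustrative:

```
import numpy as np
import pandas as pd

def evaluate_bounded_window_numpy(f, begin_index, end_index, *series):
    # Same per-row loop as the pandas variant, but the UDF sees plain
    # numpy arrays rather than pandas Series slices.
    begin_array = begin_index.values
    end_array = end_index.values
    arrays = [s.values for s in series]
    result = [f(*(a[begin_array[i]:end_array[i]] for a in arrays))
              for i in range(len(begin_array))]
    return pd.Series(result)

v = pd.Series([1.0, 2.0, 3.0, 4.0])
begin = pd.Series([0, 0, 1, 2])
end = pd.Series([2, 3, 4, 4])
print(evaluate_bounded_window_numpy(np.mean, begin, end, v))  # 1.5, 2.0, 3.0, 3.5
```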

## How was this patch tested?

New tests

Closes #22305 from icexelloss/SPARK-24561-bounded-window-udf.

Authored-by: Li Jin <ice.xelloss@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
8 changed files with 794 additions and 306 deletions


@ -2982,8 +2982,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 6.0|
+---+-----------+
This example shows using grouped aggregated UDFs as window functions. Note that only
unbounded window frame is supported at the moment:
This example shows using grouped aggregated UDFs as window functions.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql import Window
@ -2993,20 +2992,24 @@ def pandas_udf(f=None, returnType=None, functionType=None):
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
>>> w = Window \\
... .partitionBy('id') \\
... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
>>> w = (Window.partitionBy('id')
... .orderBy('v')
... .rowsBetween(-1, 0))
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
+---+----+------+
| id| v|mean_v|
+---+----+------+
| 1| 1.0| 1.5|
| 1| 1.0| 1.0|
| 1| 2.0| 1.5|
| 2| 3.0| 6.0|
| 2| 5.0| 6.0|
| 2|10.0| 6.0|
| 2| 3.0| 3.0|
| 2| 5.0| 4.0|
| 2|10.0| 7.5|
+---+----+------+
.. note:: For performance reasons, the input series to window functions are not copied.
Therefore, mutating the input series is not allowed and will cause incorrect results.
For the same reason, users should also not rely on the index of the input series.
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
.. note:: The user-defined functions are considered deterministic by default. Due to


@ -46,6 +46,15 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
def pandas_scalar_time_two(self):
return pandas_udf(lambda v: v * 2, 'double')
@property
def pandas_agg_count_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('long', PandasUDFType.GROUPED_AGG)
def count(v):
return len(v)
return count
@property
def pandas_agg_mean_udf(self):
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
@ -70,7 +79,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
@property
def unbounded_window(self):
return Window.partitionBy('id') \
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy('v')
@property
def ordered_window(self):
@ -80,6 +89,32 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
def unpartitioned_window(self):
return Window.partitionBy()
@property
def sliding_row_window(self):
return Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1)
@property
def sliding_range_window(self):
return Window.partitionBy('id').orderBy('v').rangeBetween(-2, 4)
@property
def growing_row_window(self):
return Window.partitionBy('id').orderBy('v').rowsBetween(Window.unboundedPreceding, 3)
@property
def growing_range_window(self):
return Window.partitionBy('id').orderBy('v') \
.rangeBetween(Window.unboundedPreceding, 4)
@property
def shrinking_row_window(self):
return Window.partitionBy('id').orderBy('v').rowsBetween(-2, Window.unboundedFollowing)
@property
def shrinking_range_window(self):
return Window.partitionBy('id').orderBy('v') \
.rangeBetween(-3, Window.unboundedFollowing)
def test_simple(self):
df = self.data
w = self.unbounded_window
@ -210,8 +245,6 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
def test_invalid_args(self):
df = self.data
w = self.unbounded_window
ow = self.ordered_window
mean_udf = self.pandas_agg_mean_udf
with QuietTest(self.sc):
with self.assertRaisesRegexp(
@ -220,11 +253,101 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
df.withColumn('v2', foo_udf(df['v']).over(w))
with QuietTest(self.sc):
with self.assertRaisesRegexp(
AnalysisException,
'.*Only unbounded window frame is supported.*'):
df.withColumn('mean_v', mean_udf(df['v']).over(ow))
def test_bounded_simple(self):
from pyspark.sql.functions import mean, max, min, count
df = self.data
w1 = self.sliding_row_window
w2 = self.shrinking_range_window
plus_one = self.python_plus_one
count_udf = self.pandas_agg_count_udf
mean_udf = self.pandas_agg_mean_udf
max_udf = self.pandas_agg_max_udf
min_udf = self.pandas_agg_min_udf
result1 = df.withColumn('mean_v', mean_udf(plus_one(df['v'])).over(w1)) \
.withColumn('count_v', count_udf(df['v']).over(w2)) \
.withColumn('max_v', max_udf(df['v']).over(w2)) \
.withColumn('min_v', min_udf(df['v']).over(w1))
expected1 = df.withColumn('mean_v', mean(plus_one(df['v'])).over(w1)) \
.withColumn('count_v', count(df['v']).over(w2)) \
.withColumn('max_v', max(df['v']).over(w2)) \
.withColumn('min_v', min(df['v']).over(w1))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_growing_window(self):
from pyspark.sql.functions import mean
df = self.data
w1 = self.growing_row_window
w2 = self.growing_range_window
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
.withColumn('m2', mean_udf(df['v']).over(w2))
expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
.withColumn('m2', mean(df['v']).over(w2))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_sliding_window(self):
from pyspark.sql.functions import mean
df = self.data
w1 = self.sliding_row_window
w2 = self.sliding_range_window
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
.withColumn('m2', mean_udf(df['v']).over(w2))
expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
.withColumn('m2', mean(df['v']).over(w2))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_shrinking_window(self):
from pyspark.sql.functions import mean
df = self.data
w1 = self.shrinking_row_window
w2 = self.shrinking_range_window
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
.withColumn('m2', mean_udf(df['v']).over(w2))
expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
.withColumn('m2', mean(df['v']).over(w2))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_bounded_mixed(self):
from pyspark.sql.functions import mean, max
df = self.data
w1 = self.sliding_row_window
w2 = self.unbounded_window
mean_udf = self.pandas_agg_mean_udf
max_udf = self.pandas_agg_max_udf
result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w1)) \
.withColumn('max_v', max_udf(df['v']).over(w2)) \
.withColumn('mean_unbounded_v', mean_udf(df['v']).over(w1))
expected1 = df.withColumn('mean_v', mean(df['v']).over(w1)) \
.withColumn('max_v', max(df['v']).over(w2)) \
.withColumn('mean_unbounded_v', mean(df['v']).over(w1))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
if __name__ == "__main__":


@ -145,7 +145,18 @@ def wrap_grouped_agg_pandas_udf(f, return_type):
return lambda *a: (wrapped(*a), arrow_return_type)
def wrap_window_agg_pandas_udf(f, return_type):
def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index):
window_bound_types_str = runner_conf.get('pandas_window_bound_types')
window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(',')][udf_index]
if window_bound_type == 'bounded':
return wrap_bounded_window_agg_pandas_udf(f, return_type)
elif window_bound_type == 'unbounded':
return wrap_unbounded_window_agg_pandas_udf(f, return_type)
else:
raise RuntimeError("Invalid window bound type: {} ".format(window_bound_type))
def wrap_unbounded_window_agg_pandas_udf(f, return_type):
# This is similar to grouped_agg_pandas_udf, the only difference
# is that window_agg_pandas_udf needs to repeat the return value
# to match window length, where grouped_agg_pandas_udf just returns
@ -160,7 +171,41 @@ def wrap_window_agg_pandas_udf(f, return_type):
return lambda *a: (wrapped(*a), arrow_return_type)
def read_single_udf(pickleSer, infile, eval_type, runner_conf):
def wrap_bounded_window_agg_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
def wrapped(begin_index, end_index, *series):
import pandas as pd
result = []
# Index operation is faster on np.ndarray,
# So we turn the index series into np array
# here for performance
begin_array = begin_index.values
end_array = end_index.values
for i in range(len(begin_array)):
# Note: Creating a slice from a series for each window is
# actually pretty expensive. However, there
# is no easy way to reduce the cost here.
# Note: s.iloc[i : j] is about 30% faster than s[i: j], with
# the caveat that the created slices share the same
# memory with s. Therefore, users are not allowed to
# change the value of the input series inside the window
# function. It is rare that a user needs to modify the
# input series in the window function, and therefore
# it is a reasonable restriction.
# Note: Calling reset_index on the slices will increase the cost
# of creating slices by about 100%. Therefore, for performance
# reasons we don't do it here.
series_slices = [s.iloc[begin_array[i]: end_array[i]] for s in series]
result.append(f(*series_slices))
return pd.Series(result)
return lambda *a: (wrapped(*a), arrow_return_type)
def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
row_func = None
@ -184,7 +229,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf):
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return arg_offsets, wrap_udf(func, return_type)
else:
@ -226,7 +271,8 @@ def read_udfs(pickleSer, infile, eval_type):
# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
arg_offsets, udf = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=0)
udfs['f'] = udf
split_offset = arg_offsets[0] + 1
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
@ -238,7 +284,8 @@ def read_udfs(pickleSer, infile, eval_type):
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
for i in range(num_udfs):
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
arg_offsets, udf = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=i)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))


@ -134,11 +134,6 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis("An offset window function can only be evaluated in an ordered " +
s"row-based window frame with a single offset: $w")
case _ @ WindowExpression(_: PythonUDF,
WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
if !frame.isUnbounded =>
failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")
case w @ WindowExpression(e, s) =>
// Only allow window functions with an aggregate expression or an offset window
// function or a Pandas window UDF.


@ -27,17 +27,64 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
* This class calculates and outputs windowed aggregates over the rows in a single partition.
*
* This is similar to [[WindowExec]]. The main difference is that this node does not compute
* any window aggregation values. Instead, it computes the lower and upper bound for each window
* (i.e. window bounds) and pass the data and indices to Python worker to do the actual window
* aggregation.
*
* It currently materializes all data associated with the same partition key and passes them to
* Python worker. This is not strictly necessary for sliding windows and can be improved (by
* possibly slicing data into overlapping chunks and stitching them together).
*
* This class groups window expressions by their window boundaries so that window expressions
* with the same window boundaries can share the same window bounds. The window bounds are
* prepended to the data passed to the python worker.
*
* For example, if we have:
* avg(v) over specifiedwindowframe(RowFrame, -5, 5),
* avg(v) over specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing),
* avg(v) over specifiedwindowframe(RowFrame, -3, 3),
* max(v) over specifiedwindowframe(RowFrame, -3, 3)
*
* The python input will look like:
* (lower_bound_w1, upper_bound_w1, lower_bound_w3, upper_bound_w3, v)
*
* where w1 is specifiedwindowframe(RowFrame, -5, 5)
* w2 is specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing)
* w3 is specifiedwindowframe(RowFrame, -3, 3)
*
* Note that w2 doesn't have bound indices in the python input because it is an unbounded window,
* so its bound indices will always be the same.
*
* Bounded window and Unbounded window are evaluated differently in Python worker:
* (1) Bounded window takes the window bound indices in addition to the input columns.
* Unbounded window takes only input columns.
* (2) Bounded window evaluates the udf once per input row.
* Unbounded window evaluates the udf once per window partition.
* This is controlled by Python runner conf "pandas_window_bound_types"
*
* The logic to compute window bounds is delegated to [[WindowFunctionFrame]] and shared with
* [[WindowExec]]
*
* Note this doesn't support partial aggregation and all aggregation is computed from the entire
* window.
*/
case class WindowInPandasExec(
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan) extends UnaryExecNode {
child: SparkPlan)
extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) {
override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)
@ -60,6 +107,26 @@ case class WindowInPandasExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
/**
* Helper functions and data structures for window bounds
*
* It contains:
* (1) Total number of window bound indices in the python input row
* (2) Function from frame index to its lower bound column index in the python input row
* (3) Function from frame index to its upper bound column index in the python input row
* (4) Seq from frame index to its window bound type
*/
private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType])
/**
* Enum for window bound types. Used only inside this class.
*/
private sealed case class WindowBoundType(value: String)
private object UnboundedWindow extends WindowBoundType("unbounded")
private object BoundedWindow extends WindowBoundType("bounded")
private val windowBoundTypeConf = "pandas_window_bound_types"
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
@ -73,68 +140,150 @@ case class WindowInPandasExec(
}
/**
* Create the resulting projection.
*
* This method uses Code Generation. It can only be used on the executor side.
*
* @param expressions unbound ordered function expressions.
* @return the final resulting projection.
* See [[WindowBoundHelpers]] for details.
*/
private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
val references = expressions.zipWithIndex.map { case (e, i) =>
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
private def computeWindowBoundHelpers(
factories: Seq[InternalRow => WindowFunctionFrame]
): WindowBoundHelpers = {
val functionFrames = factories.map(_(EmptyRow))
val windowBoundTypes = functionFrames.map {
case _: UnboundedWindowFunctionFrame => UnboundedWindow
case _: UnboundedFollowingWindowFunctionFrame |
_: SlidingWindowFunctionFrame |
_: UnboundedPrecedingWindowFunctionFrame => BoundedWindow
// It should be impossible to get other types of window function frame here
case frame => throw new RuntimeException(s"Unexpected window function frame $frame.")
}
val unboundToRefMap = expressions.zip(references).toMap
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
UnsafeProjection.create(
child.output ++ patchedWindowExpression,
child.output)
val requiredIndices = functionFrames.map {
case _: UnboundedWindowFunctionFrame => 0
case _ => 2
}
val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail
val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) =>
if (num == 0) {
// Sentinel values for unbounded window
(-1, -1)
} else {
(upperBoundIndex - 2, upperBoundIndex - 1)
}
}
def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1
def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2
(requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes)
}
protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute()
// Unwrap the expressions and factories from the map.
val expressionsWithFrameIndex =
windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap {
case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex))
}
val expressions = expressionsWithFrameIndex.map(_._1)
val expressionIndexToFrameIndex =
expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
// Helper functions
val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) =
computeWindowBoundHelpers(factories)
val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 }
val numFrames = factories.length
val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
val spillThreshold = conf.windowExecBufferSpillThreshold
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
// Extract window expressions and window functions
val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF])
val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e })
val udfExpressions = windowExpressions.map(_.windowFunction.asInstanceOf[PythonUDF])
// We shouldn't be chaining anything here.
// All chained python functions should only contain one function.
val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
require(pyFuncs.length == expressions.length)
val udfWindowBoundTypes = pyFuncs.indices.map(i =>
frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf)
+ (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(",")))
// Filter child output attributes down to only those that are UDF inputs.
// Also eliminate duplicate UDF inputs.
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
// Also eliminate duplicate UDF inputs. This is similar to how other Python UDF node
// handles UDF inputs.
val dataInputs = new ArrayBuffer[Expression]
val dataInputTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
if (dataInputs.exists(_.semanticEquals(e))) {
dataInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
dataInputs += e
dataInputTypes += e.dataType
dataInputs.length - 1
}
}.toArray
}.toArray
// Schema of input rows to the python runner
val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
StructField(s"_$i", dt)
})
// In addition to UDF inputs, we will prepend window bounds for each UDF.
// For bounded windows, we prepend the lower bound and upper bound. For unbounded windows,
// we do not add window bounds. (Strictly speaking, we only need the lower or upper bound
// if the window is bounded on only one side; this can be improved in the future.)
inputRDD.mapPartitionsInternal { iter =>
// Setting window bounds for each window frame. Each window frame has different bounds so
// each has its own window bound columns.
val windowBoundsInput = factories.indices.flatMap { frameIndex =>
if (isBounded(frameIndex)) {
Seq(
BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false),
BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false)
)
} else {
Seq.empty
}
}
// Setting the window bounds argOffset for each UDF. For UDFs with bounded window, argOffset
// for the UDF is (lowerBoundOffset, upperBoundOffset, inputOffset1, inputOffset2, ...)
// For UDFs with unbounded window, argOffset is (inputOffset1, inputOffset2, ...)
pyFuncs.indices.foreach { exprIndex =>
val frameIndex = expressionIndexToFrameIndex(exprIndex)
if (isBounded(frameIndex)) {
argOffsets(exprIndex) =
Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
argOffsets(exprIndex).map(_ + windowBoundsInput.length)
} else {
argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length)
}
}
val allInputs = windowBoundsInput ++ dataInputs
val allInputTypes = allInputs.map(_.dataType)
// Start processing.
child.execute().mapPartitions { iter =>
val context = TaskContext.get()
val grouped = if (partitionSpec.isEmpty) {
// Use an empty unsafe row as a place holder for the grouping key
Iterator((new UnsafeRow(), iter))
} else {
GroupedIterator(iter, partitionSpec, child.output)
// Get all relevant projections.
val resultProj = createResultProjection(expressions)
val pythonInputProj = UnsafeProjection.create(
allInputs,
windowBoundsInput.map(ref =>
AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output
)
val pythonInputSchema = StructType(
allInputTypes.zipWithIndex.map { case (dt, i) =>
StructField(s"_$i", dt)
}
)
val grouping = UnsafeProjection.create(partitionSpec, child.output)
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
@ -144,11 +293,94 @@ case class WindowInPandasExec(
queue.close()
}
val inputProj = UnsafeProjection.create(allInputs, child.output)
val pythonInput = grouped.map { case (_, rows) =>
rows.map { row =>
val stream = iter.map { row =>
queue.add(row.asInstanceOf[UnsafeRow])
inputProj(row)
row
}
val pythonInput = new Iterator[Iterator[UnsafeRow]] {
// Manage the stream and the grouping.
var nextRow: UnsafeRow = null
var nextGroup: UnsafeRow = null
var nextRowAvailable: Boolean = false
private[this] def fetchNextRow() {
nextRowAvailable = stream.hasNext
if (nextRowAvailable) {
nextRow = stream.next().asInstanceOf[UnsafeRow]
nextGroup = grouping(nextRow)
} else {
nextRow = null
nextGroup = null
}
}
fetchNextRow()
// Manage the current partition.
val buffer: ExternalAppendOnlyUnsafeRowArray =
new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
var bufferIterator: Iterator[UnsafeRow] = _
val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType))
val frames = factories.map(_(indexRow))
private[this] def fetchNextPartition() {
// Collect all the rows in the current partition.
// Before we start to fetch new input rows, make a copy of nextGroup.
val currentGroup = nextGroup.copy()
// clear last partition
buffer.clear()
while (nextRowAvailable && nextGroup == currentGroup) {
buffer.add(nextRow)
fetchNextRow()
}
// Setup the frames.
var i = 0
while (i < numFrames) {
frames(i).prepare(buffer)
i += 1
}
// Setup iteration
rowIndex = 0
bufferIterator = buffer.generateIterator()
}
// Iteration
var rowIndex = 0
override final def hasNext: Boolean =
(bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
override final def next(): Iterator[UnsafeRow] = {
// Load the next partition if we need to.
if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
fetchNextPartition()
}
val join = new JoinedRow
bufferIterator.zipWithIndex.map {
case (current, index) =>
var frameIndex = 0
while (frameIndex < numFrames) {
frames(frameIndex).write(index, current)
// If the window is unbounded we don't need to write out window bounds.
if (isBounded(frameIndex)) {
indexRow.setInt(
lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound())
indexRow.setInt(
upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound())
}
frameIndex += 1
}
pythonInputProj(join(indexRow, current))
}
}
}
@ -156,12 +388,11 @@ case class WindowInPandasExec(
pyFuncs,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
argOffsets,
windowInputSchema,
pythonInputSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(pythonInput, context.partitionId(), context)
val joined = new JoinedRow
val resultProj = createResultProjection(expressions)
windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
val leftRow = queue.remove()


@ -83,7 +83,7 @@ case class WindowExec(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
extends UnaryExecNode {
extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) {
override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)
@ -104,193 +104,6 @@ case class WindowExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
/**
* Create a bound ordering object for a given frame type and offset. A bound ordering object is
* used to determine which input row lies within the frame boundaries of an output row.
*
* This method uses Code Generation. It can only be used on the executor side.
*
* @param frame to evaluate. This can either be a Row or Range frame.
* @param bound with respect to the row.
* @param timeZone the session local timezone for time related calculations.
* @return a bound ordering object.
*/
private[this] def createBoundOrdering(
frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
(frame, bound) match {
case (RowFrame, CurrentRow) =>
RowBoundOrdering(0)
case (RowFrame, IntegerLiteral(offset)) =>
RowBoundOrdering(offset)
case (RangeFrame, CurrentRow) =>
val ordering = newOrdering(orderSpec, child.output)
RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
// Use only the first order expression when the offset is non-null.
val sortExpr = orderSpec.head
val expr = sortExpr.child
// Create the projection which returns the current 'value'.
val current = newMutableProjection(expr :: Nil, child.output)
// Flip the sign of the offset when processing the order is descending
val boundOffset = sortExpr.direction match {
case Descending => UnaryMinus(offset)
case Ascending => offset
}
// Create the projection which returns the current 'value' modified by adding the offset.
val boundExpr = (expr.dataType, boundOffset.dataType) match {
case (DateType, IntegerType) => DateAdd(expr, boundOffset)
case (TimestampType, CalendarIntervalType) =>
TimeAdd(expr, boundOffset, Some(timeZone))
case (a, b) if a== b => Add(expr, boundOffset)
}
val bound = newMutableProjection(boundExpr :: Nil, child.output)
// Construct the ordering. This is used to compare the result of current value projection
// to the result of bound value projection. This is done manually because we want to use
// Code Generation (if it is enabled).
val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
val ordering = newOrdering(boundSortExprs, Nil)
RangeBoundOrdering(ordering, current, bound)
case (RangeFrame, _) =>
sys.error("Non-Zero range offsets are not supported for windows " +
"with multiple order expressions.")
}
}
/**
* Collection containing an entry for each window frame to process. Each entry contains a frame's
* [[WindowExpression]]s and factory function for the WindowFrameFunction.
*/
private[this] lazy val windowFrameExpressionFactoryPairs = {
type FrameKey = (String, FrameType, Expression, Expression)
type ExpressionBuffer = mutable.Buffer[Expression]
val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
// Add a function and its function to the map for a given frame.
def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
val key = (tpe, fr.frameType, fr.lower, fr.upper)
val (es, fns) = framedFunctions.getOrElseUpdate(
key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
es += e
fns += fn
}
// Collect all valid window functions and group them by their frame.
windowExpression.foreach { x =>
x.foreach {
case e @ WindowExpression(function, spec) =>
val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
function match {
case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
case f => sys.error(s"Unsupported window function: $f")
}
case _ =>
}
}
// Map the groups to a (unbound) expression and frame factory pair.
var numExpressions = 0
val timeZone = conf.sessionLocalTimeZone
framedFunctions.toSeq.map {
case (key, (expressions, functionSeq)) =>
val ordinal = numExpressions
val functions = functionSeq.toArray
// Construct an aggregate processor if we need one.
def processor = AggregateProcessor(
functions,
ordinal,
child.output,
(expressions, schema) =>
newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
// Create the factory
val factory = key match {
// Offset Frame
case ("OFFSET", _, IntegerLiteral(offset), _) =>
target: InternalRow =>
new OffsetWindowFunctionFrame(
target,
ordinal,
// OFFSET frame functions are guaranteed be OffsetWindowFunctions.
functions.map(_.asInstanceOf[OffsetWindowFunction]),
child.output,
(expressions, schema) =>
newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
offset)
// Entire Partition Frame.
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
target: InternalRow => {
new UnboundedWindowFunctionFrame(target, processor)
}
// Growing Frame.
case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
target: InternalRow => {
new UnboundedPrecedingWindowFunctionFrame(
target,
processor,
createBoundOrdering(frameType, upper, timeZone))
}
// Shrinking Frame.
case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
target: InternalRow => {
new UnboundedFollowingWindowFunctionFrame(
target,
processor,
createBoundOrdering(frameType, lower, timeZone))
}
// Moving Frame.
case ("AGGREGATE", frameType, lower, upper) =>
target: InternalRow => {
new SlidingWindowFunctionFrame(
target,
processor,
createBoundOrdering(frameType, lower, timeZone),
createBoundOrdering(frameType, upper, timeZone))
}
}
// Keep track of the number of expressions. This is a side-effect in a map...
numExpressions += expressions.size
// Create the Frame Expression - Factory pair.
(expressions, factory)
}
}
/**
* Create the resulting projection.
*
* This method uses Code Generation. It can only be used on the executor side.
*
* @param expressions unbound ordered function expressions.
* @return the final resulting projection.
*/
private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
val references = expressions.zipWithIndex.map{ case (e, i) =>
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
}
val unboundToRefMap = expressions.zip(references).toMap
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
UnsafeProjection.create(
child.output ++ patchedWindowExpression,
child.output)
}
protected override def doExecute(): RDD[InternalRow] = {
// Unwrap the expressions and factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)


@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.window
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType}
abstract class WindowExecBase(
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan) extends UnaryExecNode {
/**
* Create the resulting projection.
*
* This method uses Code Generation. It can only be used on the executor side.
*
* @param expressions unbound ordered function expressions.
* @return the final resulting projection.
*/
protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
val references = expressions.zipWithIndex.map { case (e, i) =>
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
}
val unboundToRefMap = expressions.zip(references).toMap
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
UnsafeProjection.create(
child.output ++ patchedWindowExpression,
child.output)
}
/**
* Create a bound ordering object for a given frame type and offset. A bound ordering object is
* used to determine which input row lies within the frame boundaries of an output row.
*
* This method uses Code Generation. It can only be used on the executor side.
*
* @param frame to evaluate. This can either be a Row or Range frame.
* @param bound with respect to the row.
* @param timeZone the session local timezone for time related calculations.
* @return a bound ordering object.
*/
private def createBoundOrdering(
frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
(frame, bound) match {
case (RowFrame, CurrentRow) =>
RowBoundOrdering(0)
case (RowFrame, IntegerLiteral(offset)) =>
RowBoundOrdering(offset)
case (RangeFrame, CurrentRow) =>
val ordering = newOrdering(orderSpec, child.output)
RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
// Use only the first order expression when the offset is non-null.
val sortExpr = orderSpec.head
val expr = sortExpr.child
// Create the projection which returns the current 'value'.
val current = newMutableProjection(expr :: Nil, child.output)
// Flip the sign of the offset when processing the order is descending
val boundOffset = sortExpr.direction match {
case Descending => UnaryMinus(offset)
case Ascending => offset
}
// Create the projection which returns the current 'value' modified by adding the offset.
val boundExpr = (expr.dataType, boundOffset.dataType) match {
case (DateType, IntegerType) => DateAdd(expr, boundOffset)
case (TimestampType, CalendarIntervalType) =>
TimeAdd(expr, boundOffset, Some(timeZone))
case (a, b) if a == b => Add(expr, boundOffset)
}
val bound = newMutableProjection(boundExpr :: Nil, child.output)
// Construct the ordering. This is used to compare the result of current value projection
// to the result of bound value projection. This is done manually because we want to use
// Code Generation (if it is enabled).
val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
val ordering = newOrdering(boundSortExprs, Nil)
RangeBoundOrdering(ordering, current, bound)
case (RangeFrame, _) =>
sys.error("Non-Zero range offsets are not supported for windows " +
"with multiple order expressions.")
}
}
/**
* Collection containing an entry for each window frame to process. Each entry contains a frame's
* [[WindowExpression]]s and factory function for the WindowFrameFunction.
*/
protected lazy val windowFrameExpressionFactoryPairs = {
type FrameKey = (String, FrameType, Expression, Expression)
type ExpressionBuffer = mutable.Buffer[Expression]
val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
// Add a function and its function to the map for a given frame.
def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
val key = (tpe, fr.frameType, fr.lower, fr.upper)
val (es, fns) = framedFunctions.getOrElseUpdate(
key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
es += e
fns += fn
}
// Collect all valid window functions and group them by their frame.
windowExpression.foreach { x =>
x.foreach {
case e @ WindowExpression(function, spec) =>
val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
function match {
case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
case f: PythonUDF => collect("AGGREGATE", frame, e, f)
case f => sys.error(s"Unsupported window function: $f")
}
case _ =>
}
}
// Map the groups to a (unbound) expression and frame factory pair.
var numExpressions = 0
val timeZone = conf.sessionLocalTimeZone
framedFunctions.toSeq.map {
case (key, (expressions, functionSeq)) =>
val ordinal = numExpressions
val functions = functionSeq.toArray
// Construct an aggregate processor if we need one.
// Currently we don't allow mixing of Pandas UDF and SQL aggregation functions
// in a single Window physical node. Therefore, we can assume no SQL aggregation
// functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL
// aggregation function in a single physical node.
def processor = if (functions.exists(_.isInstanceOf[PythonUDF])) {
null
} else {
AggregateProcessor(
functions,
ordinal,
child.output,
(expressions, schema) =>
newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
}
// Create the factory
val factory = key match {
// Offset Frame
case ("OFFSET", _, IntegerLiteral(offset), _) =>
target: InternalRow =>
new OffsetWindowFunctionFrame(
target,
ordinal,
// OFFSET frame functions are guaranteed to be OffsetWindowFunctions.
functions.map(_.asInstanceOf[OffsetWindowFunction]),
child.output,
(expressions, schema) =>
newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
offset)
// Entire Partition Frame.
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
target: InternalRow => {
new UnboundedWindowFunctionFrame(target, processor)
}
// Growing Frame.
case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
target: InternalRow => {
new UnboundedPrecedingWindowFunctionFrame(
target,
processor,
createBoundOrdering(frameType, upper, timeZone))
}
// Shrinking Frame.
case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
target: InternalRow => {
new UnboundedFollowingWindowFunctionFrame(
target,
processor,
createBoundOrdering(frameType, lower, timeZone))
}
// Moving Frame.
case ("AGGREGATE", frameType, lower, upper) =>
target: InternalRow => {
new SlidingWindowFunctionFrame(
target,
processor,
createBoundOrdering(frameType, lower, timeZone),
createBoundOrdering(frameType, upper, timeZone))
}
}
// Keep track of the number of expressions. This is a side-effect in a map...
numExpressions += expressions.size
// Create the Frame Expression - Factory pair.
(expressions, factory)
}
}
}


@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
* Before use a frame must be prepared by passing it all the rows in the current partition. After
* preparation the update method can be called to fill the output rows.
*/
private[window] abstract class WindowFunctionFrame {
abstract class WindowFunctionFrame {
/**
* Prepare the frame for calculating the results for a partition.
*
@ -42,6 +42,20 @@ private[window] abstract class WindowFunctionFrame {
* Write the current results to the target row.
*/
def write(index: Int, current: InternalRow): Unit
/**
* The current lower window bound in the row array (inclusive).
*
* This should be called after the current row is updated via [[write]]
*/
def currentLowerBound(): Int
/**
* The current row index of the upper window bound in the row array (exclusive)
*
* This should be called after the current row is updated via [[write]]
*/
def currentUpperBound(): Int
}
object WindowFunctionFrame {
@ -62,7 +76,7 @@ object WindowFunctionFrame {
* @param newMutableProjection function used to create the projection.
* @param offset by which rows get moved within a partition.
*/
private[window] final class OffsetWindowFunctionFrame(
final class OffsetWindowFunctionFrame(
target: InternalRow,
ordinal: Int,
expressions: Array[OffsetWindowFunction],
@ -137,6 +151,10 @@ private[window] final class OffsetWindowFunctionFrame(
}
inputIndex += 1
}
override def currentLowerBound(): Int = throw new UnsupportedOperationException()
override def currentUpperBound(): Int = throw new UnsupportedOperationException()
}
/**
@ -148,7 +166,7 @@ private[window] final class OffsetWindowFunctionFrame(
* @param lbound comparator used to identify the lower bound of an output row.
* @param ubound comparator used to identify the upper bound of an output row.
*/
private[window] final class SlidingWindowFunctionFrame(
final class SlidingWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor,
lbound: BoundOrdering,
@ -169,25 +187,25 @@ private[window] final class SlidingWindowFunctionFrame(
/** The rows within current sliding window. */
private[this] val buffer = new util.ArrayDeque[InternalRow]()
/**
* Index of the first input row with a value greater than the upper bound of the current
* output row.
*/
private[this] var inputHighIndex = 0
/**
* Index of the first input row with a value equal to or greater than the lower bound of the
* current output row.
*/
private[this] var inputLowIndex = 0
private[this] var lowerBound = 0
/**
* Index of the first input row with a value greater than the upper bound of the current
* output row.
*/
private[this] var upperBound = 0
/** Prepare the frame for calculating a new partition. Reset all variables. */
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIterator = input.generateIterator()
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex = 0
inputLowIndex = 0
lowerBound = 0
upperBound = 0
buffer.clear()
}
@ -197,27 +215,27 @@ private[window] final class SlidingWindowFunctionFrame(
// Drop all rows from the buffer for which the input row value is smaller than
// the output row lower bound.
while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
while (!buffer.isEmpty && lbound.compare(buffer.peek(), lowerBound, current, index) < 0) {
buffer.remove()
inputLowIndex += 1
lowerBound += 1
bufferUpdated = true
}
// Add all rows to the buffer for which the input row value is equal to or less than
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) {
inputLowIndex += 1
while (nextRow != null && ubound.compare(nextRow, upperBound, current, index) <= 0) {
if (lbound.compare(nextRow, lowerBound, current, index) < 0) {
lowerBound += 1
} else {
buffer.add(nextRow.copy())
bufferUpdated = true
}
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex += 1
upperBound += 1
}
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
if (processor != null && bufferUpdated) {
processor.initialize(input.length)
val iter = buffer.iterator()
while (iter.hasNext) {
@ -226,6 +244,10 @@ private[window] final class SlidingWindowFunctionFrame(
processor.evaluate(target)
}
}
override def currentLowerBound(): Int = lowerBound
override def currentUpperBound(): Int = upperBound
}
/**
@ -239,29 +261,41 @@ private[window] final class SlidingWindowFunctionFrame(
* @param target to write results to.
* @param processor to calculate the row values with.
*/
private[window] final class UnboundedWindowFunctionFrame(
final class UnboundedWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor)
extends WindowFunctionFrame {
val lowerBound: Int = 0
var upperBound: Int = 0
/** Prepare the frame for calculating a new partition. Process all rows eagerly. */
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
if (processor != null) {
processor.initialize(rows.length)
val iterator = rows.generateIterator()
while (iterator.hasNext) {
processor.update(iterator.next())
}
}
upperBound = rows.length
}
/** Write the frame columns for the current row to the given target row. */
override def write(index: Int, current: InternalRow): Unit = {
// Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate
// for each row.
if (processor != null) {
processor.evaluate(target)
}
}
override def currentLowerBound(): Int = lowerBound
override def currentUpperBound(): Int = upperBound
}
/**
* The UnboundPreceding window frame calculates frames with the following SQL form:
* ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
@ -276,7 +310,7 @@ private[window] final class UnboundedWindowFunctionFrame(
* @param processor to calculate the row values with.
* @param ubound comparator used to identify the upper bound of an output row.
*/
private[window] final class UnboundedPrecedingWindowFunctionFrame(
final class UnboundedPrecedingWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor,
ubound: BoundOrdering)
@ -308,8 +342,10 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
nextRow = inputIterator.next()
}
if (processor != null) {
processor.initialize(input.length)
}
}
/** Write the frame columns for the current row to the given target row. */
override def write(index: Int, current: InternalRow): Unit = {
@ -318,17 +354,23 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
// Add all rows to the aggregates for which the input row value is equal to or less than
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
if (processor != null) {
processor.update(nextRow)
}
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputIndex += 1
bufferUpdated = true
}
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
if (processor != null && bufferUpdated) {
processor.evaluate(target)
}
}
override def currentLowerBound(): Int = 0
override def currentUpperBound(): Int = inputIndex
}
/**
@ -347,7 +389,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
* @param processor to calculate the row values with.
* @param lbound comparator used to identify the lower bound of an output row.
*/
private[window] final class UnboundedFollowingWindowFunctionFrame(
final class UnboundedFollowingWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor,
lbound: BoundOrdering)
@ -384,7 +426,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
}
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
if (processor != null && bufferUpdated) {
processor.initialize(input.length)
if (nextRow != null) {
processor.update(nextRow)
@ -395,4 +437,8 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
processor.evaluate(target)
}
}
override def currentLowerBound(): Int = inputIndex
override def currentUpperBound(): Int = input.length
}