[SPARK-1442] [SQL] [FOLLOW-UP] Address minor comments in Window Function PR (#5604).
Address marmbrus and scwf's comments in #5604.
Author: Yin Huai <yhuai@databricks.com>
Closes #5945 from yhuai/windowFollowup and squashes the following commits:
0ef879d [Yin Huai] Add collectFirst to TreeNode.
2373968 [Yin Huai] wip
4a16df9 [Yin Huai] Address minor comments for [SPARK-1442].
(cherry picked from commit 5784c8d955
)
Signed-off-by: Michael Armbrust <michael@databricks.com>
This commit is contained in:
parent
ef835dc526
commit
9dcf4f78f4
|
@ -638,11 +638,10 @@ class Analyzer(
|
||||||
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
|
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
|
||||||
// First, we group window expressions based on their Window Spec.
|
// First, we group window expressions based on their Window Spec.
|
||||||
val groupedWindowExpression = windowExpressions.groupBy { expr =>
|
val groupedWindowExpression = windowExpressions.groupBy { expr =>
|
||||||
val windowExpression = expr.find {
|
val windowSpec = expr.collectFirst {
|
||||||
case window: WindowExpression => true
|
case window: WindowExpression => window.windowSpec
|
||||||
case other => false
|
}
|
||||||
}.map(_.asInstanceOf[WindowExpression].windowSpec)
|
windowSpec.getOrElse(
|
||||||
windowExpression.getOrElse(
|
|
||||||
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
|
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
|
||||||
}.toSeq
|
}.toSeq
|
||||||
|
|
||||||
|
@ -685,7 +684,7 @@ class Analyzer(
|
||||||
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
|
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
|
||||||
if child.resolved &&
|
if child.resolved &&
|
||||||
hasWindowFunction(aggregateExprs) &&
|
hasWindowFunction(aggregateExprs) &&
|
||||||
!a.expressions.exists(!_.resolved) =>
|
a.expressions.forall(_.resolved) =>
|
||||||
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
|
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
|
||||||
// Create an Aggregate operator to evaluate aggregation functions.
|
// Create an Aggregate operator to evaluate aggregation functions.
|
||||||
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
|
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
|
||||||
|
@ -702,7 +701,7 @@ class Analyzer(
|
||||||
// Aggregate without Having clause.
|
// Aggregate without Having clause.
|
||||||
case a @ Aggregate(groupingExprs, aggregateExprs, child)
|
case a @ Aggregate(groupingExprs, aggregateExprs, child)
|
||||||
if hasWindowFunction(aggregateExprs) &&
|
if hasWindowFunction(aggregateExprs) &&
|
||||||
!a.expressions.exists(!_.resolved) =>
|
a.expressions.forall(_.resolved) =>
|
||||||
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
|
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
|
||||||
// Create an Aggregate operator to evaluate aggregation functions.
|
// Create an Aggregate operator to evaluate aggregation functions.
|
||||||
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
|
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
|
||||||
|
|
|
@ -130,6 +130,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Finds and returns the first [[TreeNode]] of the tree for which the given partial function
|
||||||
|
* is defined (pre-order), and applies the partial function to it.
|
||||||
|
*/
|
||||||
|
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = {
|
||||||
|
val lifted = pf.lift
|
||||||
|
lifted(this).orElse {
|
||||||
|
children.foldLeft(None: Option[B]) { (l, r) => l.orElse(r.collectFirst(pf)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a copy of this node where `f` has been applied to all the nodes children.
|
* Returns a copy of this node where `f` has been applied to all the nodes children.
|
||||||
*/
|
*/
|
||||||
|
@ -160,7 +171,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
|
||||||
val remainingNewChildren = newChildren.toBuffer
|
val remainingNewChildren = newChildren.toBuffer
|
||||||
val remainingOldChildren = children.toBuffer
|
val remainingOldChildren = children.toBuffer
|
||||||
val newArgs = productIterator.map {
|
val newArgs = productIterator.map {
|
||||||
// This rule is used to handle children is a input argument.
|
// Handle Seq[TreeNode] in TreeNode parameters.
|
||||||
case s: Seq[_] => s.map {
|
case s: Seq[_] => s.map {
|
||||||
case arg: TreeNode[_] if children contains arg =>
|
case arg: TreeNode[_] if children contains arg =>
|
||||||
val newChild = remainingNewChildren.remove(0)
|
val newChild = remainingNewChildren.remove(0)
|
||||||
|
|
|
@ -172,4 +172,54 @@ class TreeNodeSuite extends FunSuite {
|
||||||
expected = None
|
expected = None
|
||||||
assert(expected === actual)
|
assert(expected === actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("collectFirst") {
|
||||||
|
val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
|
||||||
|
|
||||||
|
// Collect the top node.
|
||||||
|
{
|
||||||
|
val actual = expression.collectFirst {
|
||||||
|
case add: Add => add
|
||||||
|
}
|
||||||
|
val expected =
|
||||||
|
Some(Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))))
|
||||||
|
assert(expected === actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect the first children.
|
||||||
|
{
|
||||||
|
val actual = expression.collectFirst {
|
||||||
|
case l @ Literal(1, IntegerType) => l
|
||||||
|
}
|
||||||
|
val expected = Some(Literal(1))
|
||||||
|
assert(expected === actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect an internal node (Subtract).
|
||||||
|
{
|
||||||
|
val actual = expression.collectFirst {
|
||||||
|
case sub: Subtract => sub
|
||||||
|
}
|
||||||
|
val expected = Some(Subtract(Literal(3), Literal(4)))
|
||||||
|
assert(expected === actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect a leaf node.
|
||||||
|
{
|
||||||
|
val actual = expression.collectFirst {
|
||||||
|
case l @ Literal(3, IntegerType) => l
|
||||||
|
}
|
||||||
|
val expected = Some(Literal(3))
|
||||||
|
assert(expected === actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect nothing.
|
||||||
|
{
|
||||||
|
val actual = expression.collectFirst {
|
||||||
|
case l @ Literal(100, IntegerType) => l
|
||||||
|
}
|
||||||
|
val expected = None
|
||||||
|
assert(expected === actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue