[SPARK-35551][SQL] Handle the COUNT bug for lateral subqueries

### What changes were proposed in this pull request?
This PR modifies `DecorrelateInnerQuery` to handle the COUNT bug for lateral subqueries. Similar to SPARK-15370, rewriting lateral subqueries as joins can change the semantics of the subquery and lead to incorrect answers.

However we can't reuse the existing code to handle the count bug for correlated scalar subqueries because it assumes the subquery to have a specific shape (either with Filter + Aggregate or Aggregate as the root node). Instead, this PR proposes a more generic way to handle the COUNT bug. If an Aggregate is subject to the COUNT bug, we insert a left outer domain join between the outer query and the aggregate with a `alwaysTrue` marker and rewrite the final result conditioning on the marker. For example:

```sql
-- t1: [(0, 1), (1, 2)]
-- t2: [(0, 2), (0, 3)]
select * from t1 left outer join lateral (select count(*) from t2 where t2.c1 = t1.c1)
```

Without count bug handling, the query plan is
```
Project [c1#44, c2#45, count(1)#53L]
+- Join LeftOuter, (c1#48 = c1#44)
   :- LocalRelation [c1#44, c2#45]
   +- Aggregate [c1#48], [count(1) AS count(1)#53L, c1#48]
      +- LocalRelation [c1#48]
```
and the answer is wrong:
```
+---+---+--------+
|c1 |c2 |count(1)|
+---+---+--------+
|0  |1  |2       |
|1  |2  |null    |
+---+---+--------+
```

With the count bug handling:
```
Project [c1#1, c2#2, count(1)#10L]
+- Join LeftOuter, (c1#34 <=> c1#1)
   :- LocalRelation [c1#1, c2#2]
   +- Project [if (isnull(alwaysTrue#32)) 0 else count(1)#33L AS count(1)#10L, c1#34]
      +- Join LeftOuter, (c1#5 = c1#34)
         :- Aggregate [c1#1], [c1#1 AS c1#34]
         :  +- LocalRelation [c1#1]
         +- Aggregate [c1#5], [count(1) AS count(1)#33L, c1#5, true AS alwaysTrue#32]
            +- LocalRelation [c1#5]
```
and we have the correct answer:
```
+---+---+--------+
|c1 |c2 |count(1)|
+---+---+--------+
|0  |1  |2       |
|1  |2  |0       |
+---+---+--------+
```

### Why are the changes needed?
Fix a correctness bug with lateral join rewrite.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added SQL query tests. The results are consistent with Postgres' results.

Closes #33070 from allisonwang-db/spark-35551-lateral-count-bug.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
allisonwang-db 2021-07-13 17:35:03 +08:00 committed by Wenchen Fan
parent f8a80c42ce
commit 4f760f2b1f
7 changed files with 488 additions and 48 deletions

View file

@ -50,4 +50,6 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
override def -(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
def ++(other: AttributeMap[A]): AttributeMap[A] = new AttributeMap(baseMap ++ other.baseMap)
}

View file

@ -51,4 +51,6 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
override def removed(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
def ++(other: AttributeMap[A]): AttributeMap[A] = new AttributeMap(baseMap ++ other.baseMap)
}

View file

@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.plans._
@ -126,11 +128,11 @@ object DecorrelateInnerQuery extends PredicateHelper {
* E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y}
*/
private def collectEquivalentOuterReferences(
expressions: Seq[Expression]): Map[Attribute, Attribute] = {
expressions.collect {
expressions: Seq[Expression]): AttributeMap[Attribute] = {
AttributeMap(expressions.collect {
case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute)
case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute)
}.toMap
})
}
/**
@ -138,7 +140,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
*/
private def replaceOuterReference[E <: Expression](
expression: E,
outerReferenceMap: Map[Attribute, Attribute]): E = {
outerReferenceMap: AttributeMap[Attribute]): E = {
expression.transformWithPruning(_.containsPattern(OUTER_REFERENCE)) {
case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o)
}.asInstanceOf[E]
@ -150,7 +152,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
*/
private def replaceOuterReferences[E <: Expression](
expressions: Seq[E],
outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = {
outerReferenceMap: AttributeMap[Attribute]): Seq[E] = {
expressions.map(replaceOuterReference(_, outerReferenceMap))
}
@ -212,14 +214,40 @@ object DecorrelateInnerQuery extends PredicateHelper {
}
/**
* Rewrite all [[DomainJoin]]s in the inner query to actual inner joins with the outer query.
* Rewrite all [[DomainJoin]]s in the inner query to actual joins with the outer query.
*/
def rewriteDomainJoins(
outerPlan: LogicalPlan,
innerPlan: LogicalPlan,
conditions: Seq[Expression]): LogicalPlan = innerPlan match {
case d @ DomainJoin(domainAttrs, child) =>
case d @ DomainJoin(domainAttrs, child, joinType, condition) =>
val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs)
val newChild = joinType match {
// Left outer domain joins are used to handle the COUNT bug.
case LeftOuter =>
// Replace the attributes in the domain join condition with the actual outer expressions
// and use the new join conditions to rewrite domain joins in its child. For example:
// DomainJoin [c'] LeftOuter (a = c') with domainAttrMap: { c' -> _1 }.
// Then the new conditions to use will be [(a = _1)].
assert(condition.isDefined,
s"LeftOuter domain join should always have the join condition defined:\n$d")
val newCond = condition.get.transform {
case a: Attribute => domainAttrMap.getOrElse(a, a)
}
// Recursively rewrite domain joins using the new conditions.
rewriteDomainJoins(outerPlan, child, splitConjunctivePredicates(newCond))
case Inner =>
// The decorrelation framework adds domain inner joins by traversing down the plan tree
// recursively until it reaches a node that is not correlated with the outer query.
// So the child node of a domain inner join shouldn't contain another domain join.
assert(child.find(_.isInstanceOf[DomainJoin]).isEmpty,
s"Child of a domain inner join shouldn't contain another domain join.\n$child")
child
case o =>
throw new IllegalStateException(s"Unexpected domain join type $o")
}
// We should only rewrite a domain join when all corresponding outer plan attributes
// can be found from the join condition.
if (domainAttrMap.size == domainAttrs.size) {
@ -232,21 +260,15 @@ object DecorrelateInnerQuery extends PredicateHelper {
// DomainJoin [a', b'] => Aggregate [a, b] [a AS a', b AS b']
// +- Relation [a, b]
val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan)
child match {
newChild match {
// A special optimization for OneRowRelation.
// TODO: add a more general rule to optimize join with OneRowRelation.
case _: OneRowRelation => domain
// Construct a domain join.
// Join Inner
// :- Inner Query
// +- Domain
case _ =>
// The decorrelation framework adds domain joins by traversing down the plan tree
// recursively until it reaches a node that is not correlated with the outer query.
// So the child node of a domain join shouldn't contain another domain join.
assert(child.find(_.isInstanceOf[DomainJoin]).isEmpty,
s"Child of a domain join shouldn't contain another domain join.\n$child")
Join(child, domain, Inner, None, JoinHint.NONE)
// Join joinType condition
// :- Domain
// +- Inner Query
case _ => Join(domain, newChild, joinType, condition, JoinHint.NONE)
}
} else {
throw QueryExecutionErrors.cannotRewriteDomainJoinWithConditionsError(conditions, d)
@ -257,7 +279,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
def apply(
innerPlan: LogicalPlan,
outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
outerPlan: LogicalPlan,
handleCountBug: Boolean = false): (LogicalPlan, Seq[Expression]) = {
val outputPlanInputAttrs = outerPlan.inputSet
// The return type of the recursion.
@ -265,7 +288,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
// The second parameter is a list of join conditions with the outer query.
// The third parameter is a mapping between the outer references and equivalent
// expressions from the inner query that is used to replace outer references.
type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute])
type ReturnType = (LogicalPlan, Seq[Expression], AttributeMap[Attribute])
// Decorrelate the input plan with a set of parent outer references and a boolean flag
// indicating whether the result of the plan will be aggregated. Steps:
@ -288,7 +311,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
// If there is no outer references from the parent nodes, it means all outer
// attributes can be substituted by attributes from the inner plan. So no
// domain join is needed.
(plan, Nil, Map.empty[Attribute, Attribute])
(plan, Nil, AttributeMap.empty[Attribute])
} else {
// Build the domain join with the parent outer references.
val attributes = parentOuterReferences.toSeq
@ -310,7 +333,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
val conditions = outerReferenceMap.map {
case (o, a) => EqualNullSafe(a, OuterReference(o))
}
(domainJoin, conditions.toSeq, outerReferenceMap)
(domainJoin, conditions.toSeq, AttributeMap(outerReferenceMap))
}
} else {
plan match {
@ -428,7 +451,134 @@ object DecorrelateInnerQuery extends PredicateHelper {
groupingExpressions = newGroupingExpr ++ referencesToAdd,
aggregateExpressions = newAggExpr ++ referencesToAdd,
child = newChild)
(newAggregate, joinCond, outerReferenceMap)
// Preserving domain attributes over an Aggregate with an empty grouping expression
// is subject to the "COUNT bug" that can lead to wrong answer:
//
// Suppose the original query is:
// SELECT a, (SELECT COUNT(*) cnt FROM t2 WHERE t1.a = t2.c) FROM t1
//
// Decorrelated plan:
// Project [a, scalar-subquery [a = c]]
// : +- Aggregate [c] [count(*) AS cnt, c]
// : +- Relation [c, d]
// +- Relation [a, b]
//
// After rewrite:
// Project [a, cnt]
// +- Join LeftOuter (a = c)
// :- Relation [a, b]
// +- Aggregate [c] [count(*) AS cnt, c]
// +- Relation [c, d]
//
// T1 T2 T2' (GROUP BY c)
// +---+---+ +---+---+ +---+-----+
// | a | b | | c | d | | c | cnt |
// +---+---+ +---+---+ +---+-----+
// | 0 | 1 | | 0 | 2 | | 0 | 2 |
// | 1 | 2 | | 0 | 3 | +---+-----+
// +---+---+ +---+---+
//
// T1 nested loop join T2 T1 left outer join T2'
// on (a = c): on (a = c):
// +---+-----+ +---+-----++
// | a | cnt | | a | cnt |
// +---+-----+ +---+------+
// | 0 | 2 | | 0 | 2 |
// | 1 | 0 | <--- correct | 1 | null | <--- wrong result
// +---+-----+ +---+------+
//
// If an aggregate is subject to the COUNT bug:
// 1) add a column `true AS alwaysTrue` to the result of the aggregate
// 2) insert a left outer domain join between the outer query and this aggregate
// 3) rewrite the original aggregate's output column using the default value of the
// aggregate function and the alwaysTrue column.
//
// For example, T1 left outer join T2' with `alwaysTrue` marker:
// +---+------+------------+--------------------------------+
// | c | cnt | alwaysTrue | if(isnull(alwaysTrue), 0, cnt) |
// +---+------+------------+--------------------------------+
// | 0 | 2 | true | 2 |
// | 0 | null | null | 0 | <--- correct result
// +---+------+------------+--------------------------------+
if (groupingExpressions.isEmpty && handleCountBug) {
// Evaluate the aggregate expressions with zero tuples.
val resultMap = RewriteCorrelatedScalarSubquery.evalAggregateOnZeroTups(newAggregate)
val alwaysTrue = Alias(Literal.TrueLiteral, "alwaysTrue")()
val alwaysTrueRef = alwaysTrue.toAttribute.withNullability(true)
val expressions = ArrayBuffer.empty[NamedExpression]
// Create new aliases for aggregate expressions that have non-null default
// values and reconstruct the output with the `alwaysTrue` marker.
val projectList = newAggregate.aggregateExpressions.map { a =>
resultMap.get(a.exprId) match {
// Aggregate expression is not subject to the count bug.
case Some(Literal(null, _)) | None =>
expressions += a
// The attribute is nullable since it is from the right-hand side of a
// left outer join.
a.toAttribute.withNullability(true)
case Some(default) =>
assert(a.isInstanceOf[Alias], s"Cannot have non-aliased expression $a in " +
s"aggregate that evaluates to non-null value with zero tuples.")
val newAttr = a.newInstance()
val ref = newAttr.toAttribute.withNullability(true)
expressions += newAttr
Alias(If(IsNull(alwaysTrueRef), default, ref), a.name)(a.exprId)
}
}
// Insert a placeholder left outer domain join between the outer query and
// and aggregate node and use the current collected join conditions as the
// left outer join condition.
//
// Original subquery:
// Aggregate [count(1) AS cnt]
// +- Filter (a = outer(c))
// +- Relation [a, b]
//
// After decorrelation and before COUNT bug handling:
// Aggregate [a] [count(1) AS cnt, a]
// +- Relation [a, b]
//
// joinCond with the outer query: (a = outer(c))
//
// Handle the COUNT bug:
// Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
// +- DomainJoin [c'] LeftOuter (a = c')
// +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
// +- Relation [a, b]
//
// New joinCond with the outer query: (c' <=> outer(c)), and the DomainJoin
// will be written as:
// Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
// +- Join LeftOuter (a = c')
// :- Aggregate [c] [c AS c']
// : +- OuterQuery [c, d]
// +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
// +- Relation [a, b]
//
val agg = newAggregate.copy(aggregateExpressions = expressions.toSeq :+ alwaysTrue)
// Find all outer references that are used in the join conditions.
val outerAttrs = collectOuterReferences(joinCond).toSeq
// Create new instance of the outer attributes as if they are generated inside
// the subquery by a left outer join with the outer query. Use new instance here
// to avoid conflicting join attributes with the inner query.
val domainAttrs = outerAttrs.map(_.newInstance())
val mapping = AttributeMap(outerAttrs.zip(domainAttrs))
// Use the current join conditions returned from the recursive call as the join
// conditions for the left outer join. All outer references in the join
// conditions are replaced by the newly created domain attributes.
val condition = replaceOuterReferences(joinCond, mapping).reduceOption(And)
val domainJoin = DomainJoin(domainAttrs, agg, LeftOuter, condition)
// Original domain attributes preserved through Aggregate are no longer needed.
val newProjectList = projectList.filter(!referencesToAdd.contains(_))
val project = Project(newProjectList ++ domainAttrs, domainJoin)
val newJoinCond = outerAttrs.zip(domainAttrs).map { case (outer, inner) =>
EqualNullSafe(inner, OuterReference(outer))
}
(project, newJoinCond, mapping)
} else {
(newAggregate, joinCond, outerReferenceMap)
}
case j @ Join(left, right, joinType, condition, _) =>
val outerReferences = collectOuterReferences(j.expressions)
@ -446,12 +596,12 @@ object DecorrelateInnerQuery extends PredicateHelper {
val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
decorrelate(left, newOuterReferences, aggregated)
} else {
(left, Nil, Map.empty[Attribute, Attribute])
(left, Nil, AttributeMap.empty[Attribute])
}
val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) {
decorrelate(right, newOuterReferences, aggregated)
} else {
(right, Nil, Map.empty[Attribute, Attribute])
(right, Nil, AttributeMap.empty[Attribute])
}
val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
val newJoinCond = leftJoinCond ++ rightJoinCond

View file

@ -296,9 +296,12 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
if (newCond.isEmpty) oldCond else newCond
}
def decorrelate(sub: LogicalPlan, outer: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
def decorrelate(
sub: LogicalPlan,
outer: LogicalPlan,
handleCountBug: Boolean = false): (LogicalPlan, Seq[Expression]) = {
if (SQLConf.get.decorrelateInnerQueryEnabled) {
DecorrelateInnerQuery(sub, outer)
DecorrelateInnerQuery(sub, outer, handleCountBug)
} else {
pullOutCorrelatedPredicates(sub, outer)
}
@ -315,7 +318,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
case LateralSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, plan)
val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
}
}
@ -396,7 +399,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
/**
* Statically evaluate an expression containing one or more aggregates on an empty input.
*/
private def evalAggOnZeroTups(expr: Expression) : Expression = {
private def evalAggExprOnZeroTups(expr: Expression) : Expression = {
// AggregateExpressions are Unevaluable, so we need to replace all aggregates
// in the expression with the value they would return for zero input tuples.
// Also replace attribute refs (for example, for grouping columns) with NULL.
@ -410,6 +413,24 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
tryEvalExpr(rewrittenExpr)
}
/**
* Statically evaluate an [[Aggregate]] on an empty input and return a mapping
* between its output attribute expression ID and evaluated result.
*/
def evalAggregateOnZeroTups(a: Aggregate): Map[ExprId, Expression] = {
// Some of the expressions under the Aggregate node are the join columns
// for joining with the outer query block. Fill those expressions in with
// nulls and statically evaluate the remainder.
a.aggregateExpressions.map {
case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
case alias @ Alias(_: AttributeReference, _) =>
(alias.exprId, Literal.create(null, alias.dataType))
case alias @ Alias(l: Literal, _) =>
(alias.exprId, l.copy(value = null))
case ne => (ne.exprId, evalAggExprOnZeroTups(ne))
}.toMap
}
/**
* Statically evaluate a scalar subquery on an empty input.
*
@ -454,18 +475,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
}
case Aggregate(_, aggExprs, _) =>
// Some of the expressions under the Aggregate node are the join columns
// for joining with the outer query block. Fill those expressions in with
// nulls and statically evaluate the remainder.
aggExprs.map {
case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
case alias @ Alias(_: AttributeReference, _) =>
(alias.exprId, Literal.create(null, alias.dataType))
case alias @ Alias(l: Literal, _) =>
(alias.exprId, l.copy(value = null))
case ne => (ne.exprId, evalAggOnZeroTups(ne))
}.toMap
case a: Aggregate =>
evalAggregateOnZeroTups(a)
case l: LeafNode =>
l.output.map(a => (a.exprId, Literal.create(null, a.dataType))).toMap
@ -695,7 +706,6 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(LATERAL_JOIN)) {
case LateralJoin(left, LateralSubquery(sub, _, _, joinCond), joinType, condition) =>
// TODO(SPARK-35551): handle the COUNT bug
val newRight = DecorrelateInnerQuery.rewriteDomainJoins(left, sub, joinCond)
val newCond = (condition ++ joinCond).reduceOption(And)
Join(left, newRight, joinType, newCond, JoinHint.NONE)

View file

@ -1449,9 +1449,21 @@ case class CollectMetrics(
* A placeholder for domain join that can be added when decorrelating subqueries.
* It should be rewritten during the optimization phase.
*/
case class DomainJoin(domainAttrs: Seq[Attribute], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ domainAttrs
case class DomainJoin(
domainAttrs: Seq[Attribute],
child: LogicalPlan,
joinType: JoinType = Inner,
condition: Option[Expression] = None) extends UnaryNode {
require(Seq(Inner, LeftOuter).contains(joinType), s"Unsupported domain join type $joinType")
override def output: Seq[Attribute] = joinType match {
case LeftOuter => domainAttrs ++ child.output.map(_.withNullability(true))
case _ => domainAttrs ++ child.output
}
override def producedAttributes: AttributeSet = AttributeSet(domainAttrs)
override protected def withNewChildInternal(newChild: LogicalPlan): DomainJoin =
copy(child = newChild)
}

View file

@ -86,8 +86,65 @@ SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a));
-- lateral join inside correlated subquery
SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE c1 = t1.c1);
-- TODO(SPARK-35551): handle the COUNT bug (the expected result should be (1, 2, 0))
SELECT * FROM t1, LATERAL (SELECT COUNT(*) AS cnt FROM t2 WHERE c1 = t1.c1) WHERE cnt = 0;
-- COUNT bug with a single aggregate expression
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1);
-- COUNT bug with multiple aggregate expressions
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt, SUM(c2) sum FROM t2 WHERE c1 = t1.c1);
-- COUNT bug without count aggregate
SELECT * FROM t1, LATERAL (SELECT SUM(c2) IS NULL FROM t2 WHERE t1.c1 = t2.c1);
-- COUNT bug with complex aggregate expressions
SELECT * FROM t1, LATERAL (SELECT COUNT(*) + CASE WHEN sum(c2) IS NULL THEN 0 ELSE sum(c2) END FROM t2 WHERE t1.c1 = t2.c1);
-- COUNT bug with non-empty group by columns (should not handle the count bug)
SELECT * FROM t1, LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1);
SELECT * FROM t1, LATERAL (SELECT c2, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c2);
-- COUNT bug with different join types
SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1);
SELECT * FROM t1 LEFT JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1);
SELECT * FROM t1 CROSS JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1);
-- COUNT bug with group by columns and different join types
SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1);
SELECT * FROM t1 CROSS JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1);
-- COUNT bug with non-empty join conditions
SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) ON cnt + 1 = c1;
-- COUNT bug with self join
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1);
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt > 0);
-- COUNT bug with multiple aggregates
SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1));
SELECT * FROM t1, LATERAL (SELECT COUNT(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1));
SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1));
SELECT * FROM t1, LATERAL (
SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
JOIN t2 ON cnt = t2.c1
);
-- COUNT bug with correlated predicates above the left outer join
SELECT * FROM t1, LATERAL (SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1);
SELECT * FROM t1, LATERAL (SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1);
SELECT * FROM t1, LATERAL (
SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
WHERE cnt = c1 - 1 GROUP BY cnt
);
-- COUNT bug with joins in the subquery
SELECT * FROM t1, LATERAL (
SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
JOIN t2 ON cnt = t2.c1
);
SELECT * FROM t1, LATERAL (
SELECT l.cnt + r.cnt
FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) l
JOIN (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) r
);
-- lateral subquery with group by
SELECT * FROM t1 LEFT JOIN LATERAL (SELECT MIN(c2) FROM t2 WHERE c1 = t1.c1 GROUP BY c1);

View file

@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 44
-- Number of queries: 66
-- !query
@ -389,12 +389,219 @@ struct<c1:int,c2:int>
-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) AS cnt FROM t2 WHERE c1 = t1.c1) WHERE cnt = 0
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1)
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
0 1 2
1 2 0
-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt, SUM(c2) sum FROM t2 WHERE c1 = t1.c1)
-- !query schema
struct<c1:int,c2:int,cnt:bigint,sum:bigint>
-- !query output
0 1 2 5
1 2 0 NULL
-- !query
SELECT * FROM t1, LATERAL (SELECT SUM(c2) IS NULL FROM t2 WHERE t1.c1 = t2.c1)
-- !query schema
struct<c1:int,c2:int,(sum(c2) IS NULL):boolean>
-- !query output
0 1 false
1 2 true
-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) + CASE WHEN sum(c2) IS NULL THEN 0 ELSE sum(c2) END FROM t2 WHERE t1.c1 = t2.c1)
-- !query schema
struct<c1:int,c2:int,(count(1) + CASE WHEN (sum(c2) IS NULL) THEN 0 ELSE sum(c2) END):bigint>
-- !query output
0 1 7
1 2 0
-- !query
SELECT * FROM t1, LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1)
-- !query schema
struct<c1:int,c2:int,c1:int,cnt:bigint>
-- !query output
0 1 0 2
-- !query
SELECT * FROM t1, LATERAL (SELECT c2, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c2)
-- !query schema
struct<c1:int,c2:int,c2:int,cnt:bigint>
-- !query output
0 1 2 1
0 1 3 1
-- !query
SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
0 1 2
1 2 0
-- !query
SELECT * FROM t1 LEFT JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
0 1 2
1 2 0
-- !query
SELECT * FROM t1 CROSS JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
0 1 2
1 2 0
-- !query
SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1)
-- !query schema
struct<c1:int,c2:int,c1:int,cnt:bigint>
-- !query output
0 1 0 2
1 2 NULL NULL
-- !query
SELECT * FROM t1 CROSS JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1)
-- !query schema
struct<c1:int,c2:int,c1:int,cnt:bigint>
-- !query output
0 1 0 2
-- !query
SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) ON cnt + 1 = c1
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
1 2 0
-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1)
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
0 1 1
1 2 1
-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt > 0)
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
0 1 1
1 2 1
-- !query
SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1))
-- !query schema
struct<c1:int,c2:int,sum(cnt):bigint>
-- !query output
0 1 2
1 2 0
-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1))
-- !query schema
struct<c1:int,c2:int,count(cnt):bigint>
-- !query output
0 1 1
1 2 1
-- !query
SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1))
-- !query schema
struct<c1:int,c2:int,sum(cnt):bigint>
-- !query output
0 1 2
1 2 NULL
-- !query
SELECT * FROM t1, LATERAL (
SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
JOIN t2 ON cnt = t2.c1
)
-- !query schema
struct<c1:int,c2:int,count(1):bigint>
-- !query output
0 1 0
1 2 2
-- !query
SELECT * FROM t1, LATERAL (SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1)
-- !query schema
struct<c1:int,c2:int,cnt:bigint>
-- !query output
1 2 0
-- !query
SELECT * FROM t1, LATERAL (SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1)
-- !query schema
struct<c1:int,c2:int,count(1):bigint>
-- !query output
0 1 0
1 2 1
-- !query
SELECT * FROM t1, LATERAL (
SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
WHERE cnt = c1 - 1 GROUP BY cnt
)
-- !query schema
struct<c1:int,c2:int,count(1):bigint>
-- !query output
1 2 1
-- !query
SELECT * FROM t1, LATERAL (
SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
JOIN t2 ON cnt = t2.c1
)
-- !query schema
struct<c1:int,c2:int,cnt:bigint,c1:int,c2:int>
-- !query output
1 2 0 0 2
1 2 0 0 3
-- !query
SELECT * FROM t1, LATERAL (
SELECT l.cnt + r.cnt
FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) l
JOIN (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) r
)
-- !query schema
struct<c1:int,c2:int,(cnt + cnt):bigint>
-- !query output
0 1 4
1 2 0
-- !query
SELECT * FROM t1 LEFT JOIN LATERAL (SELECT MIN(c2) FROM t2 WHERE c1 = t1.c1 GROUP BY c1)