diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 70a3885d21..1e934d0aa0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1286,8 +1286,10 @@ class Analyzer(
         resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
       case e @ Exists(sub, _, exprId) if !sub.resolved =>
         resolveSubQuery(e, plans)(Exists(_, _, exprId))
-      case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
-        val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId))
+      case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
+        val expr = resolveSubQuery(l, plans)((plan, exprs) => {
+          ListQuery(plan, exprs, exprId, plan.output)
+        })
         In(value, Seq(expr))
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 06d8350db9..9ffe646b5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -402,7 +402,7 @@ object TypeCoercion {
 
       // Handle type casting required between value expression and subquery output
      // in IN subquery.
-      case i @ In(a, Seq(ListQuery(sub, children, exprId)))
+      case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
         if !i.resolved && flattenExpr(a).length == sub.output.length =>
         // LHS is the value expression of IN subquery.
         val lhs = flattenExpr(a)
@@ -434,7 +434,8 @@ object TypeCoercion {
           case _ => CreateStruct(castedLhs)
         }
 
-        In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
+        val newSub = Project(castedRhs, sub)
+        In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
       } else {
         i
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7bf10f199f..613d6202b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -138,32 +138,33 @@ case class Not(child: Expression)
 
 case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   require(list != null, "list should not be null")
+
   override def checkInputDataTypes(): TypeCheckResult = {
-    list match {
-      case ListQuery(sub, _, _) :: Nil =>
-        val valExprs = value match {
-          case cns: CreateNamedStruct => cns.valExprs
-          case expr => Seq(expr)
-        }
-        if (valExprs.length != sub.output.length) {
-          TypeCheckResult.TypeCheckFailure(
-            s"""
-               |The number of columns in the left hand side of an IN subquery does not match the
-               |number of columns in the output of subquery.
-               |#columns in left hand side: ${valExprs.length}.
-               |#columns in right hand side: ${sub.output.length}.
-               |Left side columns:
-               |[${valExprs.map(_.sql).mkString(", ")}].
-               |Right side columns:
-               |[${sub.output.map(_.sql).mkString(", ")}].
-             """.stripMargin)
-        } else {
-          val mismatchedColumns = valExprs.zip(sub.output).flatMap {
-            case (l, r) if l.dataType != r.dataType =>
-              s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
-            case _ => None
+    val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
+    if (mismatchOpt.isDefined) {
+      list match {
+        case ListQuery(_, _, _, childOutputs) :: Nil =>
+          val valExprs = value match {
+            case cns: CreateNamedStruct => cns.valExprs
+            case expr => Seq(expr)
           }
-          if (mismatchedColumns.nonEmpty) {
+          if (valExprs.length != childOutputs.length) {
+            TypeCheckResult.TypeCheckFailure(
+              s"""
+                 |The number of columns in the left hand side of an IN subquery does not match the
+                 |number of columns in the output of subquery.
+                 |#columns in left hand side: ${valExprs.length}.
+                 |#columns in right hand side: ${childOutputs.length}.
+                 |Left side columns:
+                 |[${valExprs.map(_.sql).mkString(", ")}].
+                 |Right side columns:
+                 |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
+          } else {
+            val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
+              case (l, r) if l.dataType != r.dataType =>
+                s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
+              case _ => None
+            }
             TypeCheckResult.TypeCheckFailure(
               s"""
                  |The data type of one or more elements in the left hand side of an IN subquery
@@ -173,20 +174,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
                  |Left side:
                  |[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
                  |Right side:
-                 |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
-               """.stripMargin)
-          } else {
-            TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
+                 |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
           }
-        }
-      case _ =>
-        val mismatchOpt = list.find(l => l.dataType != value.dataType)
-        if (mismatchOpt.isDefined) {
+        case _ =>
           TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
             s"${value.dataType} != ${mismatchOpt.get.dataType}")
-        } else {
-          TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
-        }
+      }
+    } else {
+      TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index d7b493d521..c6146042ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -274,9 +274,15 @@ object ScalarSubquery {
 case class ListQuery(
     plan: LogicalPlan,
     children: Seq[Expression] = Seq.empty,
-    exprId: ExprId = NamedExpression.newExprId)
+    exprId: ExprId = NamedExpression.newExprId,
+    childOutputs: Seq[Attribute] = Seq.empty)
   extends SubqueryExpression(plan, children, exprId) with Unevaluable {
-  override def dataType: DataType = plan.schema.fields.head.dataType
+  override def dataType: DataType = if (childOutputs.length > 1) {
+    childOutputs.toStructType
+  } else {
+    childOutputs.head.dataType
+  }
+  override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def toString: String = s"list#${exprId.id} $conditionString"
@@ -284,7 +290,8 @@ case class ListQuery(
     ListQuery(
       plan.canonicalized,
       children.map(_.canonicalized),
-      ExprId(0))
+      ExprId(0),
+      childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9dbb6b14aa..4386a10162 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -68,11 +68,11 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       case (p, Not(Exists(sub, conditions, _))) =>
         val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
         Join(outerPlan, sub, LeftAnti, joinCond)
-      case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
+      case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
         val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
         val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
         Join(outerPlan, sub, LeftSemi, joinCond)
-      case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
+      case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
         // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
         // Construct the condition. A NULL in one of the conditions is regarded as a positive
         // result; such a row will be filtered out by the Anti-Join operator.
@@ -116,7 +116,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
         val exists = AttributeReference("exists", BooleanType, nullable = false)()
         newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
         exists
-      case In(value, Seq(ListQuery(sub, conditions, _))) =>
+      case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
         val exists = AttributeReference("exists", BooleanType, nullable = false)()
         val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
         val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
@@ -227,9 +227,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       case Exists(sub, children, exprId) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
         Exists(newPlan, newCond, exprId)
-      case ListQuery(sub, _, exprId) =>
+      case ListQuery(sub, _, exprId, childOutputs) =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
-        ListQuery(newPlan, newCond, exprId)
+        ListQuery(newPlan, newCond, exprId, childOutputs)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
new file mode 100644
index 0000000000..169b8737d8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class PullupCorrelatedPredicatesSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("PullupCorrelatedPredicates", Once,
+        PullupCorrelatedPredicates) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.double)
+  val testRelation2 = LocalRelation('c.int, 'd.double)
+
+  test("PullupCorrelatedPredicates should not produce unresolved plan") {
+    val correlatedSubquery =
+      testRelation2
+        .where('b < 'd)
+        .select('c)
+    val outerQuery =
+      testRelation
+        .where(In('a, Seq(ListQuery(correlatedSubquery))))
+        .select('a).analyze
+    assert(outerQuery.resolved)
+
+    val optimized = Optimize.execute(outerQuery)
+    assert(optimized.resolved)
+  }
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
index 9ea9d3c4c6..70aeb9373f 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
@@ -80,8 +80,7 @@ number of columns in the output of subquery.
 Left side columns:
 [t1.`t1a`].
 Right side columns:
-[t2.`t2a`, t2.`t2b`].
-             ;
+[t2.`t2a`, t2.`t2b`].;
 
 
 -- !query 6
@@ -102,5 +101,4 @@ number of columns in the output of subquery.
 Left side columns:
 [t1.`t1a`, t1.`t1b`].
 Right side columns:
-[t2.`t2a`].
-             ;
+[t2.`t2a`].;
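
Why the patch works: ListQuery previously re-derived its dataType from plan.schema.fields.head.dataType on every call, so an optimizer rule that swaps in a rewritten subquery plan (as PullupCorrelatedPredicates does) could leave the enclosing In expression unresolved, which is exactly what the new PullupCorrelatedPredicatesSuite test guards against. Capturing plan.output into the new childOutputs field when the Analyzer resolves the subquery makes the expression's type and resolution status independent of later plan rewrites. Below is a minimal, self-contained sketch of that pattern; Attribute, Plan, and ListQuery here are simplified stand-ins for illustration, not Spark's actual Catalyst classes.

// A minimal, runnable sketch of the output-caching pattern this patch applies.
// The types below are hypothetical stand-ins, not Spark's Catalyst classes.
object ListQueryChildOutputsSketch {

  case class Attribute(name: String, dataType: String)
  case class Plan(output: Seq[Attribute], resolved: Boolean)

  // Before the patch, dataType was re-derived from the current plan on every
  // call; replacing the plan could flip the expression back to unresolved.
  // After the patch, outputs captured at resolution time answer both
  // dataType and resolved, mirroring the new ListQuery.childOutputs field.
  case class ListQuery(plan: Plan, childOutputs: Seq[Attribute] = Seq.empty) {
    def dataType: String =
      if (childOutputs.length > 1) {
        childOutputs.map(_.dataType).mkString("struct<", ", ", ">")
      } else {
        childOutputs.head.dataType
      }
    def resolved: Boolean = plan.resolved && childOutputs.nonEmpty
  }

  def main(args: Array[String]): Unit = {
    val sub = Plan(Seq(Attribute("t2a", "int")), resolved = true)

    // The Analyzer captures the subquery's output when it resolves it,
    // mirroring ListQuery(plan, exprs, exprId, plan.output) in the patch.
    val resolvedList = ListQuery(sub, sub.output)
    assert(resolvedList.resolved && resolvedList.dataType == "int")

    // An optimizer rewrite swaps in a new plan but keeps the captured
    // outputs, as PullupCorrelatedPredicates now does, so the expression
    // stays resolved with an unchanged type.
    val afterRewrite = resolvedList.copy(plan = Plan(sub.output, resolved = true))
    assert(afterRewrite.resolved && afterRewrite.dataType == "int")
  }
}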