Revert "[SPARK-20392][SQL] Set barrier to prevent re-entering a tree"

This reverts commit 8ce0d8ffb6.
Wenchen Fan 2017-05-30 21:14:55 -07:00
parent 52ed9b289d
commit 1f5dddffa3
16 changed files with 144 additions and 151 deletions

View file

@@ -166,15 +166,14 @@ class Analyzer(
     Batch("Subquery", Once,
       UpdateOuterReferences),
     Batch("Cleanup", fixedPoint,
-      CleanupAliases,
-      EliminateBarriers)
+      CleanupAliases)
   )

   /**
    * Analyze cte definitions and substitute child plan with analyzed cte definitions.
    */
   object CTESubstitution extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case With(child, relations) =>
         substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
           case (resolved, (name, relation)) =>
@@ -202,7 +201,7 @@ class Analyzer(
    * Substitute child plan with WindowSpecDefinitions.
    */
   object WindowsSubstitution extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       // Lookup WindowSpecDefinitions. This rule works with unresolved children.
       case WithWindowDefinition(windowDefinitions, child) =>
         child.transform {
@@ -244,7 +243,7 @@ class Analyzer(
     private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
       exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)

-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
@@ -634,7 +633,7 @@ class Analyzer(
       case _ => plan
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
         EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
           case v: View =>
@@ -689,9 +688,7 @@ class Analyzer(
    * Generate a new logical plan for the right child with different expression IDs
    * for all conflicting attributes.
    */
-  private def dedupRight (left: LogicalPlan, oriRight: LogicalPlan): LogicalPlan = {
-    // Remove analysis barrier if any.
-    val right = EliminateBarriers(oriRight)
+  private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
     val conflictingAttributes = left.outputSet.intersect(right.outputSet)
     logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
       s"between $left and $right")
@@ -734,7 +731,7 @@ class Analyzer(
          * that this rule cannot handle. When that is the case, there must be another rule
          * that resolves these conflicts. Otherwise, the analysis will fail.
          */
-        oriRight
+        right
       case Some((oldRelation, newRelation)) =>
         val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
         val newRight = right transformUp {
@@ -747,7 +744,7 @@ class Analyzer(
               s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
           }
         }
-        AnalysisBarrier(newRight)
+        newRight
     }
   }
@@ -808,7 +805,7 @@ class Analyzer(
       }
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p

       // If the projection list contains Stars, expand it.
@@ -982,7 +979,7 @@ class Analyzer(
    * have no effect on the results.
    */
   object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p if !p.childrenResolved => p
       // Replace the index with the related attribute for ORDER BY,
       // which is a 1-base position of the projection list.
@@ -1038,7 +1035,7 @@ class Analyzer(
       }}
     }

-    override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case agg @ Aggregate(groups, aggs, child)
           if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
             groups.exists(!_.resolved) =>
@@ -1062,13 +1059,11 @@ class Analyzer(
    * The HAVING clause could also used a grouping columns that is not presented in the SELECT.
    */
   object ResolveMissingReferences extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
-      case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
       case sa @ Sort(_, _, child: Aggregate) => sa

-      case s @ Sort(order, _, orgChild) if !s.resolved && orgChild.resolved =>
-        val child = EliminateBarriers(orgChild)
+      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
         try {
           val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
           val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
@@ -1089,8 +1084,7 @@ class Analyzer(
           case ae: AnalysisException => s
         }

-      case f @ Filter(cond, orgChild) if !f.resolved && orgChild.resolved =>
-        val child = EliminateBarriers(orgChild)
+      case f @ Filter(cond, child) if !f.resolved && child.resolved =>
         try {
           val newCond = resolveExpressionRecursively(cond, child)
           val requiredAttrs = newCond.references.filter(_.resolved)
@@ -1117,7 +1111,7 @@ class Analyzer(
      */
     private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
       if (missingAttrs.isEmpty) {
-        return AnalysisBarrier(plan)
+        return plan
       }
       plan match {
         case p: Project =>
@@ -1189,7 +1183,7 @@ class Analyzer(
    * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
    */
   object ResolveFunctions extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case q: LogicalPlan =>
         q transformExpressions {
           case u if !u.childrenResolved => u // Skip until children are resolved.
@@ -1528,7 +1522,7 @@ class Analyzer(
     /**
      * Resolve and rewrite all subqueries in an operator tree..
      */
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       // In case of HAVING (a filter after an aggregate) we use both the aggregate and
       // its child for resolution.
       case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -1543,7 +1537,7 @@ class Analyzer(
    * Turns projections that contain aggregate expressions into aggregations.
    */
   object GlobalAggregates extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case Project(projectList, child) if containsAggregates(projectList) =>
         Aggregate(Nil, projectList, child)
     }
@@ -1569,9 +1563,7 @@ class Analyzer(
    * underlying aggregate operator and then projected away after the original operator.
    */
   object ResolveAggregateFunctions extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
-      case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) =>
-        apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier)
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case filter @ Filter(havingCondition,
              aggregate @ Aggregate(grouping, originalAggExprs, child))
           if aggregate.resolved =>
@@ -1631,8 +1623,6 @@ class Analyzer(
           case ae: AnalysisException => filter
         }

-      case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) =>
-        apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier)
       case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
         // Try resolving the ordering as though it is in the aggregate clause.
@@ -1745,7 +1735,7 @@ class Analyzer(
       }
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
         val nestedGenerator = projectList.find(hasNestedGenerator).get
         throw new AnalysisException("Generators are not supported when it's nested in " +
@@ -1803,7 +1793,7 @@ class Analyzer(
    * that wrap the [[Generator]].
    */
   object ResolveGenerate extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case g: Generate if !g.child.resolved || !g.generator.resolved => g
       case g: Generate if !g.resolved =>
         g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -2120,7 +2110,7 @@ class Analyzer(
    * put them into an inner Project and finally project them away at the outer Project.
    */
   object PullOutNondeterministic extends Rule[LogicalPlan] {
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p if !p.resolved => p // Skip unresolved nodes.
       case p: Project => p
       case f: Filter => f
@@ -2165,7 +2155,7 @@ class Analyzer(
    * and we should return null if the input is null.
    */
   object HandleNullInputsForUDF extends Rule[LogicalPlan] {
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p if !p.resolved => p // Skip unresolved nodes.

       case p => p transformExpressionsUp {
@@ -2230,7 +2220,7 @@ class Analyzer(
    * Then apply a Project on a normal Join to eliminate natural or using join.
    */
   object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
-    override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
           if left.resolved && right.resolved && j.duplicateResolved =>
         commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
@@ -2295,7 +2285,7 @@ class Analyzer(
    * to the given input attributes.
    */
   object ResolveDeserializer extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p if !p.childrenResolved => p
       case p if p.resolved => p
@@ -2381,7 +2371,7 @@ class Analyzer(
    * constructed is an inner class.
    */
   object ResolveNewInstance extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p if !p.childrenResolved => p
       case p if p.resolved => p
@@ -2415,7 +2405,7 @@ class Analyzer(
         "type of the field in the target object")
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case p if !p.childrenResolved => p
       case p if p.resolved => p
@@ -2469,7 +2459,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
     case other => trimAliases(other)
   }

-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    case Project(projectList, child) =>
      val cleanedProjectList =
        projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
@@ -2498,13 +2488,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
   }
 }

-/** Remove the barrier nodes of analysis */
-object EliminateBarriers extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
-    case AnalysisBarrier(child) => child
-  }
-}
-
 /**
  * Ignore event time watermark in batch query, which is only supported in Structured Streaming.
  * TODO: add this rule into analyzer rule list.
@@ -2554,7 +2537,7 @@ object TimeWindowing extends Rule[LogicalPlan] {
   * @return the logical plan that will generate the time windows using the Expand operator, with
   *         the Filter operator for correctness and Project for usability.
   */
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    case p: LogicalPlan if p.children.size == 1 =>
      val child = p.children.head
      val windowExpressions =

View file

@@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
     PromotePrecision(Cast(e, dataType))
   }

-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // fix decimal precision for expressions
     case q => q.transformExpressions(
       decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))

View file

@@ -103,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
     })
   )

-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
       val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
         case Some(tvf) =>

View file

@@ -206,7 +206,7 @@ object TypeCoercion {
    * instances higher in the query tree.
    */
   object PropagateTypes extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       // No propagation required for leaf nodes.
       case q: LogicalPlan if q.children.isEmpty => q
@@ -261,7 +261,7 @@
    */
   object WidenSetOperationTypes extends Rule[LogicalPlan] {

-    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p if p.analyzed => p

       case s @ SetOperation(left, right) if s.childrenResolved &&
@@ -335,7 +335,7 @@
       }
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -391,7 +391,7 @@
       }
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -449,7 +449,7 @@
     private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE)
     private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO)

-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -490,7 +490,7 @@
    * This ensure that the types for various functions are as expected.
    */
   object FunctionArgumentConversion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -580,7 +580,7 @@
    * converted to fractional types.
    */
   object Division extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who has not been resolved yet,
       // as this is an extra rule which should be applied at last.
       case e if !e.childrenResolved => e
@@ -602,7 +602,7 @@
    * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
   object CaseWhenCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
         val maybeCommonType = findWiderCommonType(c.valueTypes)
         maybeCommonType.map { commonType =>
@@ -632,7 +632,7 @@
    * Coerces the type of different branches of If statement to a common type.
    */
   object IfCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case e if !e.childrenResolved => e
       // Find tightest common type for If, if the true value and false value have different types.
       case i @ If(pred, left, right) if left.dataType != right.dataType =>
@@ -656,7 +656,7 @@
     private val acceptedTypes = Seq(DateType, TimestampType, StringType)

-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
@@ -673,7 +673,7 @@
    * Casts types according to the expected input types for [[Expression]]s.
    */
   object ImplicitTypeCasts extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e

View file

@@ -38,7 +38,7 @@ case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] {
   }

   override def apply(plan: LogicalPlan): LogicalPlan =
-    plan.transformAllExpressions(transformTimeZoneExprs)
+    plan.resolveExpressions(transformTimeZoneExprs)

   def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
 }

View file

@@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf
  * completely resolved during the batch of Resolution.
  */
 case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case v @ View(desc, output, child) if child.resolved && output != child.output =>
       val resolver = conf.resolver
       val queryColumnNames = desc.viewQueryColumnNames

View file

@@ -236,7 +236,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
   /**
    * Pull up the correlated predicates and rewrite all subqueries in an operator tree..
    */
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case f @ Filter(_, a: Aggregate) =>
       rewriteSubQueries(f, Seq(a, a.child))
     // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.

View file

@@ -46,6 +46,41 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
   /** Returns true if this subtree contains any streaming data sources. */
   def isStreaming: Boolean = children.exists(_.isStreaming == true)

+  /**
+   * Returns a copy of this node where `rule` has been recursively applied first to all of its
+   * children and then itself (post-order). When `rule` does not apply to a given node, it is left
+   * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already
+   * been marked as analyzed.
+   *
+   * @param rule the function use to transform this nodes children
+   */
+  def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+    if (!analyzed) {
+      val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
+      if (this fastEquals afterRuleOnChildren) {
+        CurrentOrigin.withOrigin(origin) {
+          rule.applyOrElse(this, identity[LogicalPlan])
+        }
+      } else {
+        CurrentOrigin.withOrigin(origin) {
+          rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
+        }
+      }
+    } else {
+      this
+    }
+  }
+
+  /**
+   * Recursively transforms the expressions of a tree, skipping nodes that have already
+   * been analyzed.
+   */
+  def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = {
+    this resolveOperators {
+      case p => p.transformExpressions(r)
+    }
+  }
+
   /** A cache for the estimated statistics, such that it will only be computed once. */
   private var statsCache: Option[Statistics] = None
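The restored resolveOperators above is what lets analyzer rules skip work on plans that are already analyzed. The following is a minimal, self-contained Scala sketch of that idea; the Node, Proj and Leaf types are toy stand-ins invented for illustration, not Spark's classes.

// Toy sketch of a post-order transform that skips sub-trees flagged as analyzed,
// mirroring the behaviour of LogicalPlan.resolveOperators above.
trait Node {
  var analyzed: Boolean = false
  def children: Seq[Node]
  def withChildren(newChildren: Seq[Node]): Node

  def resolveOps(rule: PartialFunction[Node, Node]): Node = {
    if (analyzed) {
      this // leave already-analyzed sub-trees untouched
    } else {
      // transform children first (post-order), then try the rule on this node
      val afterChildren = withChildren(children.map(_.resolveOps(rule)))
      rule.applyOrElse(afterChildren, identity[Node])
    }
  }
}

case object Leaf extends Node {
  def children: Seq[Node] = Nil
  def withChildren(newChildren: Seq[Node]): Node = this
}

case class Proj(child: Node) extends Node {
  def children: Seq[Node] = Seq(child)
  def withChildren(newChildren: Seq[Node]): Node = Proj(newChildren.head)
}

object ResolveOpsDemo extends App {
  var invocations = 0
  val rule: PartialFunction[Node, Node] = { case p: Proj => invocations += 1; p }

  // Nothing marked as analyzed: the rule sees both Proj nodes.
  Proj(Proj(Leaf)).resolveOps(rule)
  println(invocations) // 2

  invocations = 0
  val analyzedChild = Proj(Leaf)
  analyzedChild.analyzed = true // pretend this sub-tree was already analyzed
  Proj(analyzedChild).resolveOps(rule)
  println(invocations) // 1: the analyzed sub-tree is skipped
}

Running the demo prints 2 and then 1: the second traversal never re-enters the sub-tree whose analyzed flag is set, which is the re-entrancy guarantee the reverted barrier approach tried to provide structurally.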

View file

@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -897,11 +896,3 @@
   override def output: Seq[Attribute] = child.output
 }
-
-/** A logical plan for setting a barrier of analysis */
-case class AnalysisBarrier(child: LogicalPlan) extends LeafNode {
-  override def output: Seq[Attribute] = child.output
-  override def analyzed: Boolean = true
-  override def isStreaming: Boolean = child.isStreaming
-  override lazy val canonicalized: LogicalPlan = child.canonicalized
-}
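For contrast with the flag-based skipping shown earlier, the AnalysisBarrier node removed just above achieved the same goal structurally. Below is a hypothetical sketch that reuses the toy Node trait from the earlier snippet (again, not Spark's actual class): a wrapper that exposes no children, so a bottom-up traversal simply never descends into the plan it wraps.

// Toy barrier: a leaf-like wrapper that hides its sub-tree from traversals.
case class Barrier(hidden: Node) extends Node {
  def children: Seq[Node] = Nil // the wrapped plan is invisible to transforms
  def withChildren(newChildren: Seq[Node]): Node = this
}

object BarrierDemo extends App {
  var invocations = 0
  val rule: PartialFunction[Node, Node] = { case p: Proj => invocations += 1; p }

  // The rule never reaches the inner Proj because Barrier exposes no children,
  // even though nothing here is marked as analyzed.
  Proj(Barrier(Proj(Leaf))).resolveOps(rule)
  println(invocations) // 1: only the outer Proj is visited
}

Here the rule fires only once, on the outer Proj; the wrapped sub-tree is invisible to the traversal, which is why the barrier design also required an EliminateBarriers cleanup rule.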

View file

@@ -441,20 +441,6 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
     checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
   }

-  test("analysis barrier") {
-    // [[AnalysisBarrier]] will be removed after analysis
-    checkAnalysis(
-      Project(Seq(UnresolvedAttribute("tbl.a")),
-        AnalysisBarrier(SubqueryAlias("tbl", testRelation))),
-      Project(testRelation.output, SubqueryAlias("tbl", testRelation)))
-
-    // Verify we won't go through a plan wrapped in a barrier.
-    // Since we wrap an unresolved plan and analyzer won't go through it. It remains unresolved.
-    val barrier = AnalysisBarrier(Project(Seq(UnresolvedAttribute("tbl.b")),
-      SubqueryAlias("tbl", testRelation)))
-    assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'"))
-  }
-
   test("SPARK-20311 range(N) as alias") {
     def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
       SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames))

View file

@@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType

 /**
- * This suite is used to test [[LogicalPlan]]'s `transformUp` plus analysis barrier and make sure
- * it can correctly skip sub-trees that have already been marked as analyzed.
+ * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly
+ * skips sub-trees that have already been marked as analyzed.
  */
 class LogicalPlanSuite extends SparkFunSuite {
   private var invocationCount = 0
@@ -36,35 +36,37 @@ class LogicalPlanSuite extends SparkFunSuite {
   private val testRelation = LocalRelation()

-  test("transformUp runs on operators") {
+  test("resolveOperator runs on operators") {
     invocationCount = 0
     val plan = Project(Nil, testRelation)
-    plan transformUp function
+    plan resolveOperators function

     assert(invocationCount === 1)
   }

-  test("transformUp runs on operators recursively") {
+  test("resolveOperator runs on operators recursively") {
     invocationCount = 0
     val plan = Project(Nil, Project(Nil, testRelation))
-    plan transformUp function
+    plan resolveOperators function

     assert(invocationCount === 2)
   }

-  test("transformUp skips all ready resolved plans wrapped in analysis barrier") {
+  test("resolveOperator skips all ready resolved plans") {
     invocationCount = 0
-    val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation)))
-    plan transformUp function
+    val plan = Project(Nil, Project(Nil, testRelation))
+    plan.foreach(_.setAnalyzed())
+    plan resolveOperators function

     assert(invocationCount === 0)
   }

-  test("transformUp skips partially resolved plans wrapped in analysis barrier") {
+  test("resolveOperator skips partially resolved plans") {
     invocationCount = 0
-    val plan1 = AnalysisBarrier(Project(Nil, testRelation))
+    val plan1 = Project(Nil, testRelation)
     val plan2 = Project(Nil, plan1)
-    plan2 transformUp function
+    plan1.foreach(_.setAnalyzed())
+    plan2 resolveOperators function

     assert(invocationCount === 1)
   }

View file

@@ -187,9 +187,6 @@ class Dataset[T] private[sql](
     }
   }

-  // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again.
-  @transient private val planWithBarrier = AnalysisBarrier(logicalPlan)
-
   /**
    * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the
    * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use
@@ -421,7 +418,7 @@
    */
   @Experimental
   @InterfaceStability.Evolving
-  def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, planWithBarrier)
+  def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)

   /**
    * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed.
@@ -624,7 +621,7 @@
     require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0,
       s"delay threshold ($delayThreshold) should not be negative.")
     EliminateEventTimeWatermark(
-      EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, planWithBarrier))
+      EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
   }

   /**
@@ -810,7 +807,7 @@
    * @since 2.0.0
    */
   def join(right: Dataset[_]): DataFrame = withPlan {
-    Join(planWithBarrier, right.planWithBarrier, joinType = Inner, None)
+    Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
   }

   /**
@@ -888,7 +885,7 @@
     // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
     // by creating a new instance for one of the branch.
     val joined = sparkSession.sessionState.executePlan(
-      Join(planWithBarrier, right.planWithBarrier, joinType = JoinType(joinType), None))
+      Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
       .analyzed.asInstanceOf[Join]

     withPlan {
@@ -949,7 +946,7 @@
     // Trigger analysis so in the case of self-join, the analyzer will clone the plan.
     // After the cloning, left and right side will have distinct expression ids.
     val plan = withPlan(
-      Join(planWithBarrier, right.planWithBarrier, JoinType(joinType), Some(joinExprs.expr)))
+      Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)))
       .queryExecution.analyzed.asInstanceOf[Join]

     // If auto self join alias is disabled, return the plan.
@@ -958,8 +955,8 @@
     }

     // If left/right have no output set intersection, return the plan.
-    val lanalyzed = withPlan(this.planWithBarrier).queryExecution.analyzed
-    val ranalyzed = withPlan(right.planWithBarrier).queryExecution.analyzed
+    val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
+    val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
     if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
       return withPlan(plan)
     }
@@ -991,7 +988,7 @@
    * @since 2.1.0
    */
   def crossJoin(right: Dataset[_]): DataFrame = withPlan {
-    Join(planWithBarrier, right.planWithBarrier, joinType = Cross, None)
+    Join(logicalPlan, right.logicalPlan, joinType = Cross, None)
   }

   /**
@@ -1023,8 +1020,8 @@
     // etc.
     val joined = sparkSession.sessionState.executePlan(
       Join(
-        this.planWithBarrier,
-        other.planWithBarrier,
+        this.logicalPlan,
+        other.logicalPlan,
         JoinType(joinType),
         Some(condition.expr))).analyzed.asInstanceOf[Join]
@@ -1194,7 +1191,7 @@
    */
   @scala.annotation.varargs
   def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan {
-    UnresolvedHint(name, parameters, planWithBarrier)
+    UnresolvedHint(name, parameters, logicalPlan)
   }

   /**
@@ -1220,7 +1217,7 @@
    * @since 1.6.0
    */
   def as(alias: String): Dataset[T] = withTypedPlan {
-    SubqueryAlias(alias, planWithBarrier)
+    SubqueryAlias(alias, logicalPlan)
   }

   /**
@@ -1258,7 +1255,7 @@
    */
   @scala.annotation.varargs
   def select(cols: Column*): DataFrame = withPlan {
-    Project(cols.map(_.named), planWithBarrier)
+    Project(cols.map(_.named), logicalPlan)
   }

   /**
@@ -1313,8 +1310,8 @@
   @InterfaceStability.Evolving
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc, planWithBarrier.output).named :: Nil,
-      planWithBarrier)
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
+      logicalPlan)

     if (encoder.flat) {
       new Dataset[U1](sparkSession, project, encoder)
@@ -1332,8 +1329,8 @@
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(exprEnc, planWithBarrier.output).named)
-    val execution = new QueryExecution(sparkSession, Project(namedColumns, planWithBarrier))
+      columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
+    val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
@@ -1409,7 +1406,7 @@
    * @since 1.6.0
    */
   def filter(condition: Column): Dataset[T] = withTypedPlan {
-    Filter(condition.expr, planWithBarrier)
+    Filter(condition.expr, logicalPlan)
   }

   /**
@@ -1586,7 +1583,7 @@
   @Experimental
   @InterfaceStability.Evolving
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
-    val inputPlan = planWithBarrier
+    val inputPlan = logicalPlan
     val withGroupingKey = AppendColumns(func, inputPlan)
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)
@@ -1732,7 +1729,7 @@
    * @since 2.0.0
    */
   def limit(n: Int): Dataset[T] = withTypedPlan {
-    Limit(Literal(n), planWithBarrier)
+    Limit(Literal(n), logicalPlan)
   }

   /**
@@ -1761,7 +1758,7 @@
   def union(other: Dataset[T]): Dataset[T] = withSetOperator {
     // This breaks caching, but it's usually ok because it addresses a very specific use case:
     // using union to union many files or partitions.
-    CombineUnions(Union(logicalPlan, other.logicalPlan)).mapChildren(AnalysisBarrier)
+    CombineUnions(Union(logicalPlan, other.logicalPlan))
   }

   /**
@@ -1775,7 +1772,7 @@
    * @since 1.6.0
    */
   def intersect(other: Dataset[T]): Dataset[T] = withSetOperator {
-    Intersect(planWithBarrier, other.planWithBarrier)
+    Intersect(logicalPlan, other.logicalPlan)
   }

   /**
@@ -1789,7 +1786,7 @@
    * @since 2.0.0
    */
   def except(other: Dataset[T]): Dataset[T] = withSetOperator {
-    Except(planWithBarrier, other.planWithBarrier)
+    Except(logicalPlan, other.logicalPlan)
   }

   /**
@@ -1810,7 +1807,7 @@
       s"Fraction must be nonnegative, but got ${fraction}")

     withTypedPlan {
-      Sample(0.0, fraction, withReplacement, seed, planWithBarrier)()
+      Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
     }
   }
@@ -1852,15 +1849,15 @@
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the
     // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
     // from the sort order.
-    val sortOrder = planWithBarrier.output
+    val sortOrder = logicalPlan.output
       .filter(attr => RowOrdering.isOrderable(attr.dataType))
       .map(SortOrder(_, Ascending))
     val plan = if (sortOrder.nonEmpty) {
-      Sort(sortOrder, global = false, planWithBarrier)
+      Sort(sortOrder, global = false, logicalPlan)
     } else {
       // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
       cache()
-      planWithBarrier
+      logicalPlan
     }
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
@@ -1944,7 +1941,7 @@
     withPlan {
       Generate(generator, join = true, outer = false,
-        qualifier = None, generatorOutput = Nil, planWithBarrier)
+        qualifier = None, generatorOutput = Nil, logicalPlan)
     }
   }
@@ -1985,7 +1982,7 @@
     withPlan {
       Generate(generator, join = true, outer = false,
-        qualifier = None, generatorOutput = Nil, planWithBarrier)
+        qualifier = None, generatorOutput = Nil, logicalPlan)
     }
   }
@@ -2100,7 +2097,7 @@
           u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
       case Column(expr: Expression) => expr
     }
-    val attrs = this.planWithBarrier.output
+    val attrs = this.logicalPlan.output
     val colsAfterDrop = attrs.filter { attr =>
       attr != expression
     }.map(attr => Column(attr))
@@ -2148,7 +2145,7 @@
       }
       cols
     }
-    Deduplicate(groupCols, planWithBarrier, isStreaming)
+    Deduplicate(groupCols, logicalPlan, isStreaming)
   }

   /**
@@ -2297,7 +2294,7 @@
   @Experimental
   @InterfaceStability.Evolving
   def filter(func: T => Boolean): Dataset[T] = {
-    withTypedPlan(TypedFilter(func, planWithBarrier))
+    withTypedPlan(TypedFilter(func, logicalPlan))
   }

   /**
@@ -2311,7 +2308,7 @@
   @Experimental
   @InterfaceStability.Evolving
   def filter(func: FilterFunction[T]): Dataset[T] = {
-    withTypedPlan(TypedFilter(func, planWithBarrier))
+    withTypedPlan(TypedFilter(func, logicalPlan))
   }

   /**
@@ -2325,7 +2322,7 @@
   @Experimental
   @InterfaceStability.Evolving
   def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
-    MapElements[T, U](func, planWithBarrier)
+    MapElements[T, U](func, logicalPlan)
   }

   /**
@@ -2340,7 +2337,7 @@
   @InterfaceStability.Evolving
   def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
     implicit val uEnc = encoder
-    withTypedPlan(MapElements[T, U](func, planWithBarrier))
+    withTypedPlan(MapElements[T, U](func, logicalPlan))
   }

   /**
@@ -2356,7 +2353,7 @@
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
     new Dataset[U](
       sparkSession,
-      MapPartitions[T, U](func, planWithBarrier),
+      MapPartitions[T, U](func, logicalPlan),
       implicitly[Encoder[U]])
   }
@@ -2387,7 +2384,7 @@
     val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
     Dataset.ofRows(
       sparkSession,
-      MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, planWithBarrier))
+      MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
   }

   /**
@@ -2557,7 +2554,7 @@
    * @since 1.6.0
    */
   def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
-    Repartition(numPartitions, shuffle = true, planWithBarrier)
+    Repartition(numPartitions, shuffle = true, logicalPlan)
   }

   /**
@@ -2571,7 +2568,7 @@
    */
   @scala.annotation.varargs
   def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), planWithBarrier, numPartitions)
+    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
   }

   /**
@@ -2587,8 +2584,7 @@
   @scala.annotation.varargs
   def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
     RepartitionByExpression(
-      partitionExprs.map(_.expr), planWithBarrier,
-      sparkSession.sessionState.conf.numShufflePartitions)
+      partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions)
   }

   /**
@@ -2609,7 +2605,7 @@
    * @since 1.6.0
    */
   def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
-    Repartition(numPartitions, shuffle = false, planWithBarrier)
+    Repartition(numPartitions, shuffle = false, logicalPlan)
   }

   /**
@@ -2698,7 +2694,7 @@
    */
   lazy val rdd: RDD[T] = {
     val objectType = exprEnc.deserializer.dataType
-    val deserialized = CatalystSerde.deserialize[T](planWithBarrier)
+    val deserialized = CatalystSerde.deserialize[T](logicalPlan)
     sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows =>
       rows.map(_.get(0, objectType).asInstanceOf[T])
     }
@@ -2812,7 +2808,7 @@
       comment = None,
       properties = Map.empty,
       originalText = None,
-      child = planWithBarrier,
+      child = logicalPlan,
       allowExisting = false,
       replace = replace,
       viewType = viewType)
@@ -2981,7 +2977,7 @@
       }
     }
     withTypedPlan {
-      Sort(sortOrder, global = global, planWithBarrier)
+      Sort(sortOrder, global = global, logicalPlan)
     }
   }

View file

@@ -416,7 +416,7 @@ case class DataSource(
       }.head
     }
     // For partitioned relation r, r.schema's column ordering can be different from the column
    // ordering of data.logicalPlan (partition columns are all moved after data column). This
    // will be adjusted within InsertIntoHadoopFsRelation.
     InsertIntoHadoopFsRelationCommand(
       outputPath = outputPath,

View file

@@ -38,7 +38,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
     sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined
   }

-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case u: UnresolvedRelation if maybeSQLFile(u) =>
       try {
         val dataSource = DataSource(

View file

@@ -241,7 +241,7 @@ class PlannerSuite extends SharedSQLContext {
   test("collapse adjacent repartitions") {
     val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
     def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
-    assert(countRepartitions(doubleRepartitioned.queryExecution.analyzed) === 3)
+    assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
     assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2)
     doubleRepartitioned.queryExecution.optimizedPlan match {
       case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) =>

View file

@@ -88,7 +88,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] {
     }
   }

-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case c @ CreateTable(t, _, query) if DDLUtils.isHiveTable(t) =>
       // Finds the database name if the name does not exist.
       val dbName = t.identifier.database.getOrElse(session.catalog.currentDatabase)
@@ -115,7 +115,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] {
 }

 class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case relation: CatalogRelation
         if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty =>
       val table = relation.tableMeta
@@ -146,7 +146,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
  * `PreprocessTableInsertion`.
  */
 object HiveAnalysis extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case InsertIntoTable(r: CatalogRelation, partSpec, query, overwrite, ifPartitionNotExists)
         if DDLUtils.isHiveTable(r.tableMeta) =>
       InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists)