[SPARK-14345][SQL] Decouple deserializer expression resolution from ObjectOperator

## What changes were proposed in this pull request?

This PR decouples deserializer expression resolution from `ObjectOperator`, so that deserializer expressions can be used in normal operators. This is needed by #12061 and #12067; I abstracted the logic out and put it in this PR to reduce the code changes needed later.
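
To make the shape of the change concrete, here is a minimal, self-contained toy model of the idea (plain Scala; `Expr`, `Attr`, `BoundRef`, and `UnresolvedDeser` are stand-ins for the catalyst classes, not the real API): instead of each object operator exposing `(deserializer, attributes)` pairs for a special-cased branch in `ResolveReferences`, the unresolved deserializer carries its own resolution input, and a single generic rule handles it.

```scala
// Toy model only: stand-ins for catalyst's Expression/Attribute/BoundReference.
sealed trait Expr
case class Attr(name: String) extends Expr
case class BoundRef(ordinal: Int) extends Expr
// The wrapper carries the attributes the deserializer should be resolved against.
case class UnresolvedDeser(deserializer: Expr, inputs: Seq[Attr]) extends Expr

// One generic rule: rebind ordinals against the carried attributes.
def resolveDeserializer(e: Expr): Expr = e match {
  case UnresolvedDeser(BoundRef(i), inputs) => inputs(i)
  case other => other
}

// A MapGroups-like key deserializer resolves against the grouping attributes,
// not against the child's full output, with no operator-specific hook needed.
val groupingAttrs = Seq(Attr("key"))
assert(resolveDeserializer(UnresolvedDeser(BoundRef(0), groupingAttrs)) == Attr("key"))
```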

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12131 from cloud-fan/separate.
Wenchen Fan 2016-04-05 10:53:54 -07:00 committed by Michael Armbrust
parent e4bd504120
commit f77f11c671
5 changed files with 153 additions and 126 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
import java.lang.reflect.Modifier
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
@@ -87,9 +85,11 @@ class Analyzer(
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
ResolveDeserializer ::
ResolveNewInstance ::
ResolveUpCast ::
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveUpCast ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveSortReferences ::
ResolveGenerate ::
@@ -499,18 +499,9 @@ class Analyzer(
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}
// A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
// should be resolved by their corresponding attributes instead of children's output.
case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
val deserializerToAttributes = o.deserializers.map {
case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
}.toMap
o.transformExpressions {
case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
resolveDeserializer(expr, attributes)
}.getOrElse(expr)
}
// Skips plans that contain deserializer expressions, as they should be resolved by another
// rule: ResolveDeserializer.
case plan if containsDeserializer(plan.expressions) => plan
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -526,38 +517,6 @@
}
}
private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
exprs.exists { expr =>
!expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
}
}
def resolveDeserializer(
deserializer: Expression,
attributes: Seq[Attribute]): Expression = {
val unbound = deserializer transform {
case b: BoundReference => attributes(b.ordinal)
}
resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
case n: NewInstance
// If this is an inner class of another class, register the outer object in `OuterScopes`.
// Note that static inner classes (e.g., inner classes within Scala objects) don't need
// outer pointer registration.
if n.outerPointer.isEmpty &&
n.cls.isMemberClass &&
!Modifier.isStatic(n.cls.getModifiers) =>
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
throw new AnalysisException(
s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
"access to the scope that this class was defined in.\n" +
"Try moving this class out of its parent class.")
}
n.copy(outerPointer = Some(outer))
}
}
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
@@ -623,6 +582,10 @@ class Analyzer(
}
}
private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
}
protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
@@ -1475,7 +1438,94 @@ class Analyzer(
Project(projectList, Join(left, right, joinType, newCondition))
}
/**
* Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
* to the given input attributes.
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
case p => p transformExpressions {
case UnresolvedDeserializer(deserializer, inputAttributes) =>
val inputs = if (inputAttributes.isEmpty) {
p.children.flatMap(_.output)
} else {
inputAttributes
}
val unbound = deserializer transform {
case b: BoundReference => inputs(b.ordinal)
}
resolveExpression(unbound, LocalRelation(inputs), throws = true)
}
}
}
/**
* Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
* constructed is an inner class.
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
case p => p transformExpressions {
case n: NewInstance if n.childrenResolved && !n.resolved =>
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
throw new AnalysisException(
s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
"access to the scope that this class was defined in.\n" +
"Try moving this class out of its parent class.")
}
n.copy(outerPointer = Some(outer))
}
}
}
/**
* Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
"type of the field in the target object")
}
private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
toPrecedence > 0 && fromPrecedence > toPrecedence
}
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p if !p.childrenResolved => p
case p if p.resolved => p
case p => p transformExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u
case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
fail(child, to, walkedTypePath)
case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
fail(child, to, walkedTypePath)
case (from, to) if illegalNumericPrecedence(from, to) =>
fail(child, to, walkedTypePath)
case (TimestampType, DateType) =>
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
case _ => Cast(child, dataType.asNullable)
}
}
}
}
}
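
As a gloss on the `illegalNumericPrecedence` check above: a cast is rejected when the target type sits strictly lower in the numeric precedence order than the source type. A minimal, self-contained restatement (the precedence order is an assumption mirroring `HiveTypeCoercion.numericPrecedence`; strings stand in for catalyst `DataType`s):

```scala
// Assumed precedence order, lowest to highest, mirroring HiveTypeCoercion.numericPrecedence.
val numericPrecedence =
  Seq("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

// Same shape as the check in ResolveUpCast: narrowing numeric casts are illegal.
def illegalNumericPrecedence(from: String, to: String): Boolean = {
  val fromPrecedence = numericPrecedence.indexOf(from)
  val toPrecedence = numericPrecedence.indexOf(to)
  toPrecedence > 0 && fromPrecedence > toPrecedence
}

assert(illegalNumericPrecedence("LongType", "IntegerType"))  // may truncate: fails analysis
assert(!illegalNumericPrecedence("IntegerType", "LongType")) // widening: allowed, becomes a Cast
```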
/**
@@ -1559,45 +1609,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
}
/**
* Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
"type of the field in the target object")
}
private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
toPrecedence > 0 && fromPrecedence > toPrecedence
}
def apply(plan: LogicalPlan): LogicalPlan = {
plan transformAllExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u
case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
fail(child, to, walkedTypePath)
case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
fail(child, to, walkedTypePath)
case (from, to) if illegalNumericPrecedence(from, to) =>
fail(child, to, walkedTypePath)
case (TimestampType, DateType) =>
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
case _ => Cast(child, dataType.asNullable)
}
}
}
}
/**
* Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
* figure out how many windows a time column can map to, we over-estimate the number of windows and

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

@@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
override lazy val resolved = false
}
/**
* Holds the deserializer expression and the attributes that are available during its resolution.
* A deserializer expression is a special kind of expression that is not always resolved against the
* children's output, but against given attributes, e.g. the `keyDeserializer` in `MapGroups` should
* be resolved against `groupingAttributes` instead of the children's output.
*
* @param deserializer The unresolved deserializer expression.
* @param inputAttributes The input attributes used to resolve the deserializer expression; can be
* empty if we want to resolve the deserializer against the children's output.
*/
case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
extends UnaryExpression with Unevaluable with NonSQLExpression {
// The input attributes used to resolve the deserializer expression must all be resolved.
require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
override def child: Expression = deserializer
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}
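
A small sketch of this contract, runnable against the catalyst classes named in this diff (the bare `BoundReference` used as the deserializer is a simplified stand-in for what an encoder actually produces):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference}
import org.apache.spark.sql.types.LongType

// The attribute the deserializer should be resolved against.
val input = AttributeReference("value", LongType)()
val deser = UnresolvedDeserializer(BoundReference(0, LongType, nullable = false), Seq(input))

assert(!deser.resolved)                           // unresolved until ResolveDeserializer runs
assert(deser.inputAttributes.forall(_.resolved))  // enforced by the require(...) above
```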

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

@@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
@@ -317,11 +317,11 @@ case class ExpressionEncoder[T](
def resolve(
schema: Seq[Attribute],
outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)
// Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer, check
// analysis, go through optimizer, etc.
val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
val plan = Project(
Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
import java.lang.reflect.Modifier
import scala.annotation.tailrec
import scala.language.existentials
import scala.reflect.ClassTag
@@ -112,7 +114,7 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
override def children: Seq[Expression] = arguments.+:(targetObject)
override def children: Seq[Expression] = targetObject +: arguments
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -214,6 +216,16 @@ case class NewInstance(
override def children: Seq[Expression] = arguments
override lazy val resolved: Boolean = {
// If the class to construct is an inner class, we need to get its outer pointer, or this
// expression should be regarded as unresolved.
// Note that static inner classes (e.g., inner classes within Scala objects) don't need
// outer pointer registration.
val needOuterPointer =
outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
childrenResolved && !needOuterPointer
}
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
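
The `needOuterPointer` condition deserves a concrete illustration. A self-contained sketch of the reflection check (plain Scala; `Outer` and `Holder` are made-up names for illustration):

```scala
import java.lang.reflect.Modifier

object Outer { case class ObjectNested(i: Int) } // nested in a Scala object: static, no outer pointer
class Holder { case class Inner(i: Int) }        // nested in a class: keeps a reference to its Holder

// The same condition used by NewInstance.resolved above.
def needOuterPointer(cls: Class[_]): Boolean =
  cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)

assert(!needOuterPointer(classOf[Outer.ObjectNested]))
assert(needOuterPointer(classOf[Holder#Inner])) // stays unresolved until OuterScopes supplies the outer instance
```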

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{ObjectType, StructType}
@@ -32,13 +33,6 @@ trait ObjectOperator extends LogicalPlan {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
/**
* An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
* It must also provide the attributes that are available during the resolution of each
* deserializer.
*/
def deserializers: Seq[(Expression, Seq[Attribute])]
/**
* The object type that is produced by the user-defined function. Note that the return type here
* is the same whether or not the operator outputs serialized data.
@@ -71,7 +65,7 @@ object MapPartitions {
child: LogicalPlan): MapPartitions = {
MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
encoderFor[T].deserializer,
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
encoderFor[U].namedExpressions,
child)
}
@@ -87,9 +81,7 @@ case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
child: LogicalPlan) extends UnaryNode with ObjectOperator
/** Factory for constructing new `AppendColumns` nodes. */
object AppendColumns {
@@ -98,7 +90,7 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
encoderFor[T].deserializer,
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
encoderFor[U].namedExpressions,
child)
}
@@ -120,8 +112,6 @@ case class AppendColumns(
override def output: Seq[Attribute] = child.output ++ newColumns
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -133,8 +123,8 @@ object MapGroups {
child: LogicalPlan): MapGroups = {
new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
encoderFor[K].deserializer,
encoderFor[T].deserializer,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
encoderFor[U].namedExpressions,
groupingAttributes,
dataAttributes,
@@ -158,11 +148,7 @@ case class MapGroups(
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
child: LogicalPlan) extends UnaryNode with ObjectOperator {
override def deserializers: Seq[(Expression, Seq[Attribute])] =
Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
}
child: LogicalPlan) extends UnaryNode with ObjectOperator
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
@@ -170,22 +156,24 @@ object CoGroup {
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
leftData: Seq[Attribute],
rightData: Seq[Attribute],
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup = {
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
encoderFor[Key].deserializer,
encoderFor[Left].deserializer,
encoderFor[Right].deserializer,
// The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
// resolve the `keyDeserializer` based on either of them; here we pick the left one.
UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup),
UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr),
UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr),
encoderFor[Result].namedExpressions,
leftGroup,
rightGroup,
leftData,
rightData,
leftAttr,
rightAttr,
left,
right)
}
@@ -206,10 +194,4 @@ case class CoGroup(
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode with ObjectOperator {
override def deserializers: Seq[(Expression, Seq[Attribute])] =
// The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to resolve
// the `keyDeserializer` based on either of them; here we pick the left one.
Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
}
right: LogicalPlan) extends BinaryNode with ObjectOperator