[SPARK-14345][SQL] Decouple deserializer expression resolution from ObjectOperator
## What changes were proposed in this pull request? This PR decouples deserializer expression resolution from `ObjectOperator`, so that we can use deserializer expressions in normal operators. This is needed by #12061 and #12067; I abstracted the logic out and put it in this PR to reduce code changes in the future. ## How was this patch tested? Existing tests. Author: Wenchen Fan <wenchen@databricks.com> Closes #12131 from cloud-fan/separate.
This commit is contained in:
parent
e4bd504120
commit
f77f11c671
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import java.lang.reflect.Modifier
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
|
@ -87,9 +85,11 @@ class Analyzer(
|
|||
Batch("Resolution", fixedPoint,
|
||||
ResolveRelations ::
|
||||
ResolveReferences ::
|
||||
ResolveDeserializer ::
|
||||
ResolveNewInstance ::
|
||||
ResolveUpCast ::
|
||||
ResolveGroupingAnalytics ::
|
||||
ResolvePivot ::
|
||||
ResolveUpCast ::
|
||||
ResolveOrdinalInOrderByAndGroupBy ::
|
||||
ResolveSortReferences ::
|
||||
ResolveGenerate ::
|
||||
|
@ -499,18 +499,9 @@ class Analyzer(
|
|||
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
|
||||
}
|
||||
|
||||
// A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
|
||||
// should be resolved by their corresponding attributes instead of children's output.
|
||||
case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
|
||||
val deserializerToAttributes = o.deserializers.map {
|
||||
case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
|
||||
}.toMap
|
||||
|
||||
o.transformExpressions {
|
||||
case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
|
||||
resolveDeserializer(expr, attributes)
|
||||
}.getOrElse(expr)
|
||||
}
|
||||
// Skips plan which contains deserializer expressions, as they should be resolved by another
|
||||
// rule: ResolveDeserializer.
|
||||
case plan if containsDeserializer(plan.expressions) => plan
|
||||
|
||||
case q: LogicalPlan =>
|
||||
logTrace(s"Attempting to resolve ${q.simpleString}")
|
||||
|
@ -526,38 +517,6 @@ class Analyzer(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Returns true if at least one of the given expressions is not resolved yet, or still contains
 * a [[BoundReference]] somewhere in its tree.
 */
private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
  // True when a BoundReference node can be found anywhere inside the expression tree.
  def hasBoundReference(e: Expression): Boolean = e.find {
    case _: BoundReference => true
    case _ => false
  }.isDefined

  exprs.exists(e => !e.resolved || hasBoundReference(e))
}
|
||||
|
||||
/**
 * Resolves `deserializer` against `attributes`.
 *
 * Every [[BoundReference]] is first replaced by the attribute found at its ordinal, then the
 * result is resolved against a [[LocalRelation]] built from `attributes`. Finally, every
 * [[NewInstance]] of a non-static member class that has no outer pointer yet receives the outer
 * scope looked up in `OuterScopes`; an [[AnalysisException]] is thrown if that lookup yields
 * null.
 */
def resolveDeserializer(
    deserializer: Expression,
    attributes: Seq[Attribute]): Expression = {
  // Swap ordinal-based references for the attribute at the same position.
  val withAttributes = deserializer.transform {
    case bound: BoundReference => attributes(bound.ordinal)
  }
  val resolvedExpr = resolveExpression(withAttributes, LocalRelation(attributes), throws = true)

  // Static member classes (e.g. classes nested in Scala objects) need no outer pointer.
  def lacksOuterPointer(inst: NewInstance): Boolean =
    inst.outerPointer.isEmpty &&
      inst.cls.isMemberClass &&
      !Modifier.isStatic(inst.cls.getModifiers)

  resolvedExpr.transform {
    case inst: NewInstance if lacksOuterPointer(inst) =>
      val outerScope = Option(OuterScopes.getOuterScope(inst.cls)).getOrElse {
        throw new AnalysisException(
          s"Unable to generate an encoder for inner class `${inst.cls.getName}` without " +
            "access to the scope that this class was defined in.\n" +
            "Try moving this class out of its parent class.")
      }
      inst.copy(outerPointer = Some(outerScope))
  }
}
|
||||
|
||||
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
|
||||
expressions.map {
|
||||
case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
|
||||
|
@ -623,6 +582,10 @@ class Analyzer(
|
|||
}
|
||||
}
|
||||
|
||||
/** Returns true if any of the given expression trees contains an [[UnresolvedDeserializer]]. */
private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
  exprs.exists { expr =>
    expr.find {
      case _: UnresolvedDeserializer => true
      case _ => false
    }.isDefined
  }
}
|
||||
|
||||
protected[sql] def resolveExpression(
|
||||
expr: Expression,
|
||||
plan: LogicalPlan,
|
||||
|
@ -1475,7 +1438,94 @@ class Analyzer(
|
|||
Project(projectList, Join(left, right, joinType, newCondition))
|
||||
}
|
||||
|
||||
/**
 * Replaces each [[UnresolvedDeserializer]] with its deserializer expression, resolved against
 * the input attributes it carries. When no input attributes were given, the output of the
 * operator's children is used instead.
 */
object ResolveDeserializer extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    // Nothing to do until the children are resolved, or once the operator itself is resolved.
    case p if !p.childrenResolved || p.resolved => p

    case p => p transformExpressions {
      case UnresolvedDeserializer(deserializer, inputAttributes) =>
        val inputs =
          if (inputAttributes.nonEmpty) inputAttributes else p.children.flatMap(_.output)

        // Swap ordinal-based references for the input attribute at the same position.
        val withAttributes = deserializer transform {
          case bound: BoundReference => inputs(bound.ordinal)
        }
        resolveExpression(withAttributes, LocalRelation(inputs), throws = true)
    }
  }
}
|
||||
|
||||
/**
 * Resolves [[NewInstance]] expressions that are still unresolved although their children are
 * resolved, by looking up the outer scope of the class being constructed and attaching it as the
 * outer pointer. Throws an [[AnalysisException]] when no outer scope can be found.
 */
object ResolveNewInstance extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    // Nothing to do until the children are resolved, or once the operator itself is resolved.
    case p if !p.childrenResolved || p.resolved => p

    case p => p transformExpressions {
      case inst: NewInstance if inst.childrenResolved && !inst.resolved =>
        val outerScope = Option(OuterScopes.getOuterScope(inst.cls)).getOrElse {
          throw new AnalysisException(
            s"Unable to generate an encoder for inner class `${inst.cls.getName}` without " +
              "access to the scope that this class was defined in.\n" +
              "Try moving this class out of its parent class.")
        }
        inst.copy(outerPointer = Some(outerScope))
    }
  }
}
|
||||
|
||||
/**
 * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
 *
 * Operators whose children are not resolved yet, or which are already resolved, are left
 * untouched. An [[UpCast]] whose child is still unresolved is kept as is.
 */
object ResolveUpCast extends Rule[LogicalPlan] {
  /** Always throws an [[AnalysisException]] that describes the rejected up cast. */
  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
    throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
      s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
      "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
      "You can either add an explicit cast to the input data or choose a higher precision " +
      "type of the field in the target object")
  }

  /**
   * Returns true when `to` appears in `HiveTypeCoercion.numericPrecedence` and `from` sits at a
   * later position in that sequence than `to`.
   */
  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
    val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
    val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
    // `indexOf` returns -1 for a type that is absent and 0 for the first entry, so the lower
    // bound has to be inclusive. With `> 0`, a cast to the first entry was never rejected.
    toPrecedence >= 0 && fromPrecedence > toPrecedence
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case p if !p.childrenResolved => p
    case p if p.resolved => p

    case p => p transformExpressions {
      // The child's data type is not available yet; leave the up cast in place.
      case u @ UpCast(child, _, _) if !child.resolved => u

      case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
          fail(child, to, walkedTypePath)
        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
          fail(child, to, walkedTypePath)
        case (from, to) if illegalNumericPrecedence(from, to) =>
          fail(child, to, walkedTypePath)
        case (TimestampType, DateType) =>
          fail(child, DateType, walkedTypePath)
        case (StringType, to: NumericType) =>
          fail(child, to, walkedTypePath)
        case _ => Cast(child, dataType.asNullable)
      }
    }
  }
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1559,45 +1609,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
 *
 * Every expression in the plan is visited. An `UpCast` whose child is still unresolved is kept
 * as is.
 */
object ResolveUpCast extends Rule[LogicalPlan] {
  /** Always throws an `AnalysisException` that describes the rejected up cast. */
  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
    throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
      s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
      "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
      "You can either add an explicit cast to the input data or choose a higher precision " +
      "type of the field in the target object")
  }

  /**
   * Returns true when `to` appears in `HiveTypeCoercion.numericPrecedence` and `from` sits at a
   * later position in that sequence than `to`.
   */
  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
    val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
    val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
    // `indexOf` returns -1 for a type that is absent and 0 for the first entry, so the lower
    // bound has to be inclusive. With `> 0`, a cast to the first entry was never rejected.
    toPrecedence >= 0 && fromPrecedence > toPrecedence
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    plan transformAllExpressions {
      // The child's data type is not available yet; leave the up cast in place.
      case u @ UpCast(child, _, _) if !child.resolved => u

      case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
          fail(child, to, walkedTypePath)
        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
          fail(child, to, walkedTypePath)
        case (from, to) if illegalNumericPrecedence(from, to) =>
          fail(child, to, walkedTypePath)
        case (TimestampType, DateType) =>
          fail(child, DateType, walkedTypePath)
        case (StringType, to: NumericType) =>
          fail(child, to, walkedTypePath)
        case _ => Cast(child, dataType.asNullable)
      }
    }
  }
}
|
||||
|
||||
/**
|
||||
* Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
|
||||
* figure out how many windows a time column can map to, we over-estimate the number of windows and
|
||||
|
|
|
@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
|
|||
|
||||
override lazy val resolved = false
|
||||
}
|
||||
|
||||
/**
 * Wraps a deserializer expression together with the attributes that are available while it is
 * being resolved. A deserializer expression is special in that it is not always resolved by the
 * children's output but by explicitly given attributes, e.g. the `keyDeserializer` in
 * `MapGroups` should be resolved by `groupingAttributes` instead of children output.
 *
 * This expression always reports itself as unresolved, and asking for its data type,
 * foldability or nullability throws an [[UnresolvedException]].
 *
 * @param deserializer The unresolved deserializer expression
 * @param inputAttributes The input attributes used to resolve the deserializer expression; can
 *                        be empty if the deserializer should be resolved by children output.
 */
case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
  extends UnaryExpression with Unevaluable with NonSQLExpression {
  // Reject construction if any of the supplied input attributes is still unresolved.
  require(!inputAttributes.exists(!_.resolved), "Input attributes must all be resolved.")

  override lazy val resolved = false

  override def child: Expression = deserializer

  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
}
|
||||
|
|
|
@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
|
|||
|
||||
import org.apache.spark.sql.{AnalysisException, Encoder}
|
||||
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
|
||||
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
|
||||
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
|
||||
|
@ -317,11 +317,11 @@ case class ExpressionEncoder[T](
|
|||
def resolve(
|
||||
schema: Seq[Attribute],
|
||||
outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
|
||||
val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)
|
||||
|
||||
// Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer, check
|
||||
// analysis, go through optimizer, etc.
|
||||
val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
|
||||
val plan = Project(
|
||||
Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
|
||||
LocalRelation(schema))
|
||||
val analyzedPlan = SimpleAnalyzer.execute(plan)
|
||||
SimpleAnalyzer.checkAnalysis(analyzedPlan)
|
||||
copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import java.lang.reflect.Modifier
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.language.existentials
|
||||
import scala.reflect.ClassTag
|
||||
|
@ -112,7 +114,7 @@ case class Invoke(
|
|||
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
|
||||
|
||||
override def nullable: Boolean = true
|
||||
override def children: Seq[Expression] = arguments.+:(targetObject)
|
||||
override def children: Seq[Expression] = targetObject +: arguments
|
||||
|
||||
override def eval(input: InternalRow): Any =
|
||||
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
|
||||
|
@ -214,6 +216,16 @@ case class NewInstance(
|
|||
|
||||
override def children: Seq[Expression] = arguments
|
||||
|
||||
/**
 * This expression counts as resolved only when all of its children are resolved and it does not
 * still lack an outer pointer. An outer pointer is required when the class to construct is a
 * non-static member class; static member classes (e.g. classes nested in Scala objects) do not
 * need one.
 */
override lazy val resolved: Boolean = {
  def isNonStaticMemberClass: Boolean =
    cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)

  val outerPointerMissing = outerPointer.isEmpty && isNonStaticMemberClass
  childrenResolved && !outerPointerMissing
}
|
||||
|
||||
override def eval(input: InternalRow): Any =
|
||||
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.Encoder
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
|
||||
import org.apache.spark.sql.catalyst.encoders._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.{ObjectType, StructType}
|
||||
|
@ -32,13 +33,6 @@ trait ObjectOperator extends LogicalPlan {
|
|||
|
||||
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
|
||||
|
||||
/**
|
||||
* An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
|
||||
* It must also provide the attributes that are available during the resolution of each
|
||||
* deserializer.
|
||||
*/
|
||||
def deserializers: Seq[(Expression, Seq[Attribute])]
|
||||
|
||||
/**
|
||||
* The object type that is produced by the user defined function. Note that the return type here
|
||||
* is the same whether or not the operator is output serialized data.
|
||||
|
@ -71,7 +65,7 @@ object MapPartitions {
|
|||
child: LogicalPlan): MapPartitions = {
|
||||
MapPartitions(
|
||||
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
|
||||
encoderFor[T].deserializer,
|
||||
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
|
||||
encoderFor[U].namedExpressions,
|
||||
child)
|
||||
}
|
||||
|
@ -87,9 +81,7 @@ case class MapPartitions(
|
|||
func: Iterator[Any] => Iterator[Any],
|
||||
deserializer: Expression,
|
||||
serializer: Seq[NamedExpression],
|
||||
child: LogicalPlan) extends UnaryNode with ObjectOperator {
|
||||
override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
|
||||
}
|
||||
child: LogicalPlan) extends UnaryNode with ObjectOperator
|
||||
|
||||
/** Factory for constructing new `AppendColumn` nodes. */
|
||||
object AppendColumns {
|
||||
|
@ -98,7 +90,7 @@ object AppendColumns {
|
|||
child: LogicalPlan): AppendColumns = {
|
||||
new AppendColumns(
|
||||
func.asInstanceOf[Any => Any],
|
||||
encoderFor[T].deserializer,
|
||||
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
|
||||
encoderFor[U].namedExpressions,
|
||||
child)
|
||||
}
|
||||
|
@ -120,8 +112,6 @@ case class AppendColumns(
|
|||
override def output: Seq[Attribute] = child.output ++ newColumns
|
||||
|
||||
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
|
||||
|
||||
override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
|
||||
}
|
||||
|
||||
/** Factory for constructing new `MapGroups` nodes. */
|
||||
|
@ -133,8 +123,8 @@ object MapGroups {
|
|||
child: LogicalPlan): MapGroups = {
|
||||
new MapGroups(
|
||||
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
|
||||
encoderFor[K].deserializer,
|
||||
encoderFor[T].deserializer,
|
||||
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
|
||||
UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
|
||||
encoderFor[U].namedExpressions,
|
||||
groupingAttributes,
|
||||
dataAttributes,
|
||||
|
@ -158,11 +148,7 @@ case class MapGroups(
|
|||
serializer: Seq[NamedExpression],
|
||||
groupingAttributes: Seq[Attribute],
|
||||
dataAttributes: Seq[Attribute],
|
||||
child: LogicalPlan) extends UnaryNode with ObjectOperator {
|
||||
|
||||
override def deserializers: Seq[(Expression, Seq[Attribute])] =
|
||||
Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
|
||||
}
|
||||
child: LogicalPlan) extends UnaryNode with ObjectOperator
|
||||
|
||||
/** Factory for constructing new `CoGroup` nodes. */
|
||||
object CoGroup {
|
||||
|
@ -170,22 +156,24 @@ object CoGroup {
|
|||
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
|
||||
leftGroup: Seq[Attribute],
|
||||
rightGroup: Seq[Attribute],
|
||||
leftData: Seq[Attribute],
|
||||
rightData: Seq[Attribute],
|
||||
leftAttr: Seq[Attribute],
|
||||
rightAttr: Seq[Attribute],
|
||||
left: LogicalPlan,
|
||||
right: LogicalPlan): CoGroup = {
|
||||
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
|
||||
|
||||
CoGroup(
|
||||
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
|
||||
encoderFor[Key].deserializer,
|
||||
encoderFor[Left].deserializer,
|
||||
encoderFor[Right].deserializer,
|
||||
// The `leftGroup` and `rightGroup` are guaranteed to be of same schema, so it's safe to
|
||||
// resolve the `keyDeserializer` based on either of them, here we pick the left one.
|
||||
UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup),
|
||||
UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr),
|
||||
UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr),
|
||||
encoderFor[Result].namedExpressions,
|
||||
leftGroup,
|
||||
rightGroup,
|
||||
leftData,
|
||||
rightData,
|
||||
leftAttr,
|
||||
rightAttr,
|
||||
left,
|
||||
right)
|
||||
}
|
||||
|
@ -206,10 +194,4 @@ case class CoGroup(
|
|||
leftAttr: Seq[Attribute],
|
||||
rightAttr: Seq[Attribute],
|
||||
left: LogicalPlan,
|
||||
right: LogicalPlan) extends BinaryNode with ObjectOperator {
|
||||
|
||||
override def deserializers: Seq[(Expression, Seq[Attribute])] =
|
||||
// The `leftGroup` and `rightGroup` are guaranteed to be of same schema, so it's safe to resolve
|
||||
// the `keyDeserializer` based on either of them, here we pick the left one.
|
||||
Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
|
||||
}
|
||||
right: LogicalPlan) extends BinaryNode with ObjectOperator
|
||||
|
|
Loading…
Reference in a new issue