[SPARK-5817] [SQL] Fix bug of udtf with column names
This fixes a bug hit by queries such as:

```sql
select d from (select explode(array(1,1)) d from src limit 1) t
```

which previously failed analysis with an exception like:

```
org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7
    at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
    at scala.collection.AbstractTraversable.map(Traversable.scala:105)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
```

Solving the bug requires refactoring the UDTF code path. The major changes are:

* Simplified UDTF development: a UDTF no longer manages its output attribute names; instead, `logical.Generate` handles that properly.
* A UDTF is now asked only for its output schema (data types) during logical plan analysis.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #4602 from chenghao-intel/explode_bug and squashes the following commits:

c2a5132 [Cheng Hao] add back resolved for Alias
556e982 [Cheng Hao] revert the unnecessary change
002c361 [Cheng Hao] change the rule of resolved for Generate
04ae500 [Cheng Hao] add qualifier only for generator output
5ee5d2c [Cheng Hao] prepend the new qualifier
d2e8b43 [Cheng Hao] Update the code as feedback
ca5e7f4 [Cheng Hao] shrink the commits
commit 7662ec23bb (parent 2a24bf92e6)
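Background for the diff below: after this patch a `Generator` implementation declares only its output element types via `elementTypes`, and `logical.Generate` owns the output attribute names. A minimal sketch of a custom generator under the new contract; the class `Repeat` and its semantics are hypothetical, and the internal Catalyst APIs shown match this era of Spark only:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical UDTF that emits each input value twice. It declares only element
// types (no attribute names); naming is handled by logical.Generate/ResolveGenerate.
case class Repeat(child: Expression)
  extends Generator with trees.UnaryNode[Expression] {

  // A single non-nullable IntegerType output column.
  override def elementTypes: Seq[(DataType, Boolean)] = (IntegerType, false) :: Nil

  // Produce two identical one-column rows per input row.
  override def eval(input: Row): TraversableOnce[Row] = {
    val value = child.eval(input)
    Row(value) :: Row(value) :: Nil
  }
}
```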
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.util.collection.OpenHashSet
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -59,6 +58,7 @@ class Analyzer(
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolveSortReferences ::
+      ResolveGenerate ::
       ImplicitGenerate ::
       ResolveFunctions ::
       GlobalAggregates ::
@@ -474,8 +474,59 @@ class Analyzer(
    */
  object ImplicitGenerate extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case Project(Seq(Alias(g: Generator, _)), child) =>
-        Generate(g, join = false, outer = false, None, child)
+      case Project(Seq(Alias(g: Generator, name)), child) =>
+        Generate(g, join = false, outer = false,
+          qualifier = None, UnresolvedAttribute(name) :: Nil, child)
+      case Project(Seq(MultiAlias(g: Generator, names)), child) =>
+        Generate(g, join = false, outer = false,
+          qualifier = None, names.map(UnresolvedAttribute(_)), child)
    }
  }
+
+  /**
+   * Resolve the Generate, if the output names specified, we will take them, otherwise
+   * we will try to provide the default names, which follow the same rule with Hive.
+   */
+  object ResolveGenerate extends Rule[LogicalPlan] {
+    // Construct the output attributes for the generator,
+    // The output attribute names can be either specified or
+    // auto generated.
+    private def makeGeneratorOutput(
+        generator: Generator,
+        generatorOutput: Seq[Attribute]): Seq[Attribute] = {
+      val elementTypes = generator.elementTypes
+
+      if (generatorOutput.length == elementTypes.length) {
+        generatorOutput.zip(elementTypes).map {
+          case (a, (t, nullable)) if !a.resolved =>
+            AttributeReference(a.name, t, nullable)()
+          case (a, _) => a
+        }
+      } else if (generatorOutput.length == 0) {
+        elementTypes.zipWithIndex.map {
+          // keep the default column names as Hive does _c0, _c1, _cN
+          case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
+        }
+      } else {
+        throw new AnalysisException(
+          s"""
+             |The number of aliases supplied in the AS clause does not match
+             |the number of columns output by the UDTF expected
+             |${elementTypes.size} aliases but got ${generatorOutput.size}
+           """.stripMargin)
+      }
+    }
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case p: Generate if !p.child.resolved || !p.generator.resolved => p
+      case p: Generate if p.resolved == false =>
+        // if the generator output names are not specified, we will use the default ones.
+        Generate(
+          p.generator,
+          join = p.join,
+          outer = p.outer,
+          p.qualifier,
+          makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
+    }
+  }
 }
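The naming behavior that `ResolveGenerate` implements is easiest to see from SQL. A few hedged examples, assuming a `HiveContext` bound to `sqlContext` and the usual `src(key, value)` test table:

```scala
// Unnamed generator output gets Hive-style default names _c0, _c1, ...
sqlContext.sql("SELECT _c0 FROM (SELECT explode(array(1, 2)) FROM src LIMIT 1) t")

// An explicit alias is adopted, so the query from the bug report now resolves:
sqlContext.sql("SELECT d FROM (SELECT explode(array(1, 1)) d FROM src LIMIT 1) t")

// Supplying the wrong number of aliases hits the new AnalysisException branch:
// explode over a map produces two columns, but only one alias is given here.
sqlContext.sql("SELECT explode(map('k', 'v')) AS (k1) FROM src")
```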
@@ -38,6 +38,12 @@ trait CheckAnalysis {
       throw new AnalysisException(msg)
     }
 
+    def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
+      exprs.flatMap(_.collect {
+        case e: Generator => true
+      }).length >= 1
+    }
+
     def checkAnalysis(plan: LogicalPlan): Unit = {
       // We transform up and order the rules so as to catch the first possible failure instead
       // of the result of cascading resolution failures.
@@ -110,6 +116,12 @@ trait CheckAnalysis {
           failAnalysis(
             s"unresolved operator ${operator.simpleString}")
 
+        case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
+          failAnalysis(
+            s"""Only a single table generating function is allowed in a SELECT clause, found:
+               | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
+
+
         case _ => // Analysis successful!
     }
   }
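This check turns the multi-generator case into a clean, early error. For example, assuming the same `sqlContext` and `src` table as the tests further below:

```scala
// Fails analysis with "Only a single table generating function is allowed in a
// SELECT clause, ..." instead of producing a confusing downstream error.
sqlContext.sql("SELECT explode(map(key, value)), key FROM src")
```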
@@ -284,12 +284,13 @@ package object dsl {
         seed: Int = (math.random * 1000).toInt): LogicalPlan =
       Sample(fraction, withReplacement, seed, logicalPlan)
 
+    // TODO specify the output column names
     def generate(
         generator: Generator,
         join: Boolean = false,
         outer: Boolean = false,
         alias: Option[String] = None): LogicalPlan =
-      Generate(generator, join, outer, None, logicalPlan)
+      Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
 
     def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
       InsertIntoTable(
@@ -42,47 +42,30 @@ abstract class Generator extends Expression {
 
   override type EvaluatedType = TraversableOnce[Row]
 
-  override lazy val dataType =
-    ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
+  // TODO ideally we should return the type of ArrayType(StructType),
+  // however, we don't keep the output field names in the Generator.
+  override def dataType: DataType = throw new UnsupportedOperationException
 
   override def nullable: Boolean = false
 
   /**
-   * Should be overridden by specific generators. Called only once for each instance to ensure
-   * that rule application does not change the output schema of a generator.
+   * The output element data types in structure of Seq[(DataType, Nullable)]
+   * TODO we probably need to add more information like metadata etc.
    */
-  protected def makeOutput(): Seq[Attribute]
-
-  private var _output: Seq[Attribute] = null
-
-  def output: Seq[Attribute] = {
-    if (_output == null) {
-      _output = makeOutput()
-    }
-    _output
-  }
+  def elementTypes: Seq[(DataType, Boolean)]
 
   /** Should be implemented by child classes to perform specific Generators. */
   override def eval(input: Row): TraversableOnce[Row]
-
-  /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
-  override def makeCopy(newArgs: Array[AnyRef]): this.type = {
-    val copy = super.makeCopy(newArgs)
-    copy._output = _output
-    copy
-  }
 }
 
 /**
  * A generator that produces its output using the provided lambda function.
  */
 case class UserDefinedGenerator(
-    schema: Seq[Attribute],
+    elementTypes: Seq[(DataType, Boolean)],
     function: Row => TraversableOnce[Row],
     children: Seq[Expression])
-  extends Generator{
-
-  override protected def makeOutput(): Seq[Attribute] = schema
+  extends Generator {
 
   override def eval(input: Row): TraversableOnce[Row] = {
     // TODO(davies): improve this
@@ -98,30 +81,18 @@ case class UserDefinedGenerator(
 /**
  * Given an input array produces a sequence of rows for each value in the array.
  */
-case class Explode(attributeNames: Seq[String], child: Expression)
+case class Explode(child: Expression)
   extends Generator with trees.UnaryNode[Expression] {
 
   override lazy val resolved =
     child.resolved &&
     (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
 
-  private lazy val elementTypes = child.dataType match {
+  override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
     case ArrayType(et, containsNull) => (et, containsNull) :: Nil
     case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil
   }
 
-  // TODO: Move this pattern into Generator.
-  protected def makeOutput() =
-    if (attributeNames.size == elementTypes.size) {
-      attributeNames.zip(elementTypes).map {
-        case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
-      }
-    } else {
-      elementTypes.zipWithIndex.map {
-        case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()
-      }
-    }
-
   override def eval(input: Row): TraversableOnce[Row] = {
     child.dataType match {
       case ArrayType(_, _) =>
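To make the new contract concrete: `Explode` now derives its element types purely from the child's data type, one entry per generated column. An illustrative sketch (the attribute names are made up):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Explode}
import org.apache.spark.sql.types._

// An array column explodes into a single column of the element type.
val arr = AttributeReference("arrCol", ArrayType(IntegerType, containsNull = false))()
Explode(arr).elementTypes   // Seq((IntegerType, false))

// A map column explodes into a key column and a value column.
val kv = AttributeReference("mapCol", MapType(StringType, LongType))()
Explode(kv).elementTypes    // Seq((StringType, false), (LongType, true))
```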
@@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
   extends NamedExpression with trees.UnaryNode[Expression] {
 
   override type EvaluatedType = Any
+  // Alias(Generator, xx) need to be transformed into Generate(generator, ...)
+  override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
 
   override def eval(input: Row): Any = child.eval(input)
 
@@ -482,16 +482,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
 object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case filter @ Filter(condition,
-        generate @ Generate(generator, join, outer, alias, grandChild)) =>
+    case filter @ Filter(condition, g: Generate) =>
       // Predicates that reference attributes produced by the `Generate` operator cannot
       // be pushed below the operator.
       val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
-        conjunct => conjunct.references subsetOf grandChild.outputSet
+        conjunct => conjunct.references subsetOf g.child.outputSet
       }
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
-        val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
+        val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
+          g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
         stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
       } else {
         filter
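Schematically, only conjuncts that reference no generated column move beneath the `Generate`. A sketch in the Catalyst test DSL, borrowing `testRelationWithArrayType` (columns 'a, 'b, 'c_arr) from the `FilterPushdownSuite` changes further below:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Explode

// 'b >= 5 touches only the child relation and is pushed below the Generate;
// 'c > 6 references the generated column and must stay above it.
val query = testRelationWithArrayType
  .generate(Explode('c_arr), join = true, outer = false, Some("arr"))
  .where(('b >= 5) && ('c > 6))
```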
@@ -40,34 +40,43 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
  * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
  * programming with one important additional feature, which allows the input rows to be joined with
  * their output.
+ * @param generator the generator expression
  * @param join  when true, each output row is implicitly joined with the input tuple that produced
  *              it.
  * @param outer when true, each input row will be output at least once, even if the output of the
  *              given `generator` is empty. `outer` has no effect when `join` is false.
- * @param alias when set, this string is applied to the schema of the output of the transformation
- *              as a qualifier.
+ * @param qualifier Qualifier for the attributes of generator(UDTF)
+ * @param generatorOutput The output schema of the Generator.
+ * @param child Children logical plan node
  */
 case class Generate(
     generator: Generator,
     join: Boolean,
     outer: Boolean,
-    alias: Option[String],
+    qualifier: Option[String],
+    generatorOutput: Seq[Attribute],
     child: LogicalPlan)
   extends UnaryNode {
 
-  protected def generatorOutput: Seq[Attribute] = {
-    val output = alias
-      .map(a => generator.output.map(_.withQualifiers(a :: Nil)))
-      .getOrElse(generator.output)
-    if (join && outer) {
-      output.map(_.withNullability(true))
-    } else {
-      output
-    }
+  override lazy val resolved: Boolean = {
+    generator.resolved &&
+      childrenResolved &&
+      generator.elementTypes.length == generatorOutput.length &&
+      !generatorOutput.exists(!_.resolved)
   }
 
-  override def output: Seq[Attribute] =
-    if (join) child.output ++ generatorOutput else generatorOutput
+  // we don't want the gOutput to be taken as part of the expressions
+  // as that will cause exceptions like unresolved attributes etc.
+  override def expressions: Seq[Expression] = generator :: Nil
+
+  def output: Seq[Attribute] = {
+    val qualified = qualifier.map(q =>
+      // prepend the new qualifier to the existed one
+      generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers))
+    ).getOrElse(generatorOutput)
+
+    if (join) child.output ++ qualified else qualified
+  }
 }
 
 case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
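In SQL terms, the qualifier is the `LATERAL VIEW` alias: `output` prepends it to each generator attribute's existing qualifiers, so the generated column stays reachable both bare and qualified. A hedged example, again assuming `sqlContext` over `src`:

```scala
// `tbl` becomes the Generate qualifier, so both `col` and `tbl.col` resolve.
sqlContext.sql("SELECT tbl.col FROM src LATERAL VIEW explode(array(1, 2)) tbl AS col")
```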
@@ -90,7 +90,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
 
     assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)
 
-    val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
+    val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
     assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
 
     assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
@@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: predicate referenced no generated column") {
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), true, false, Some("arr"))
         .where(('b >= 5) && ('a > 6))
     }
     val optimized = Optimize(originalQuery.analyze)
     val correctAnswer = {
       testRelationWithArrayType
         .where(('b >= 5) && ('a > 6))
-        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
+        .generate(Explode('c_arr), true, false, Some("arr")).analyze
     }
 
     comparePlans(optimized, correctAnswer)
   }
 
   test("generate: part of conjuncts referenced generated column") {
-    val generator = Explode(Seq("c"), 'c_arr)
+    val generator = Explode('c_arr)
     val originalQuery = {
       testRelationWithArrayType
         .generate(generator, true, false, Some("arr"))
@@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: all conjuncts referenced generated column") {
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), true, false, Some("arr"))
         .where(('c > 6) || ('b > 5)).analyze
     }
     val optimized = Optimize(originalQuery)
@@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -711,12 +711,16 @@ class DataFrame private[sql](
    */
   def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
     val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
-    val attributes = schema.toAttributes
+
+    val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
+    val names = schema.toAttributes.map(_.name)
+
     val rowFunction =
       f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
-    val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
+    val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
 
-    Generate(generator, join = true, outer = false, None, logicalPlan)
+    Generate(generator, join = true, outer = false,
+      qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
   }
 
   /**
@@ -733,12 +737,17 @@ class DataFrame private[sql](
     : DataFrame = {
     val dataType = ScalaReflection.schemaFor[B].dataType
     val attributes = AttributeReference(outputColumn, dataType)() :: Nil
+    // TODO handle the metadata?
+    val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
+    val names = attributes.map(_.name)
+
     def rowFunction(row: Row): TraversableOnce[Row] = {
       f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
     }
-    val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
+    val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
 
-    Generate(generator, join = true, outer = false, None, logicalPlan)
+    Generate(generator, join = true, outer = false,
+      qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
   }
 
   /////////////////////////////////////////////////////////////////////////////
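Call-site usage of the two `explode` overloads is unchanged by this patch; only their plumbing into `UserDefinedGenerator` differs. A minimal sketch, assuming a DataFrame `df` with a String column named "sentence" (the column and case-class names are illustrative):

```scala
import org.apache.spark.sql.Row
import sqlContext.implicits._  // for the $"..." column syntax

case class Word(word: String)

// Multi-column variant: the case class fields name the generated output columns.
val multi = df.explode($"sentence") { row: Row =>
  row.getString(0).split(" ").map(Word(_))
}

// Single-column variant: input and output column names are passed explicitly.
val single = df.explode("sentence", "word") { s: String => s.split(" ") }
```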
@@ -27,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._
  * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
  * programming with one important additional feature, which allows the input rows to be joined with
  * their output.
+ * @param generator the generator expression
  * @param join  when true, each output row is implicitly joined with the input tuple that produced
  *              it.
  * @param outer when true, each input row will be output at least once, even if the output of the
  *              given `generator` is empty. `outer` has no effect when `join` is false.
+ * @param output the output attributes of this node, which constructed in analysis phase,
+ *               and we can not change it, as the parent node bound with it already.
  */
 @DeveloperApi
 case class Generate(
     generator: Generator,
     join: Boolean,
     outer: Boolean,
+    output: Seq[Attribute],
     child: SparkPlan)
   extends UnaryNode {
 
-  // This must be a val since the generator output expr ids are not preserved by serialization.
-  protected val generatorOutput: Seq[Attribute] = {
-    if (join && outer) {
-      generator.output.map(_.withNullability(true))
-    } else {
-      generator.output
-    }
-  }
-
-  // This must be a val since the generator output expr ids are not preserved by serialization.
-  override val output =
-    if (join) child.output ++ generatorOutput else generatorOutput
-
   val boundGenerator = BindReferences.bindReference(generator, child.output)
 
   override def execute(): RDD[Row] = {
     if (join) {
       child.execute().mapPartitions { iter =>
-        val nullValues = Seq.fill(generator.output.size)(Literal(null))
+        val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null))
         // Used to produce rows with no matches when outer = true.
         val outerProjection =
           newProjection(child.output ++ nullValues, child.output)
 
-        val joinProjection =
-          newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput)
+        val joinProjection = newProjection(output, output)
         val joinedRow = new JoinedRow
 
         iter.flatMap {row =>
@@ -312,8 +312,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Except(planLater(left), planLater(right)) :: Nil
       case logical.Intersect(left, right) =>
         execution.Intersect(planLater(left), planLater(right)) :: Nil
-      case logical.Generate(generator, join, outer, _, child) =>
-        execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
+      case g @ logical.Generate(generator, join, outer, _, _, child) =>
+        execution.Generate(
+          generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
@@ -249,7 +249,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
         catalog.CreateTables ::
         catalog.PreInsertionCasts ::
         ExtractPythonUdfs ::
-        ResolveUdtfsAlias ::
         sources.PreInsertCastAndRename ::
         Nil
       }
@@ -725,12 +725,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       val alias =
         getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
 
-      Generate(
-        nodesToGenerator(clauses),
-        join = true,
-        outer = false,
-        Some(alias.toLowerCase),
-        withWhere)
+      val (generator, attributes) = nodesToGenerator(clauses)
+      Generate(
+        generator,
+        join = true,
+        outer = false,
+        Some(alias.toLowerCase),
+        attributes.map(UnresolvedAttribute(_)),
+        withWhere)
     }.getOrElse(withWhere)
 
     // The projection of the query can either be a normal projection, an aggregation
@@ -833,12 +835,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
       val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
 
-      Generate(
-        nodesToGenerator(clauses),
-        join = true,
-        outer = isOuter.nonEmpty,
-        Some(alias.toLowerCase),
-        nodeToRelation(relationClause))
+      val (generator, attributes) = nodesToGenerator(clauses)
+      Generate(
+        generator,
+        join = true,
+        outer = isOuter.nonEmpty,
+        Some(alias.toLowerCase),
+        attributes.map(UnresolvedAttribute(_)),
+        nodeToRelation(relationClause))
 
     /* All relations, possibly with aliases or sampling clauses. */
     case Token("TOK_TABREF", clauses) =>
@@ -1311,7 +1315,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
 
   val explode = "(?i)explode".r
-  def nodesToGenerator(nodes: Seq[Node]): Generator = {
+  def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = {
     val function = nodes.head
 
     val attributes = nodes.flatMap {
@@ -1321,7 +1325,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     function match {
       case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
-        Explode(attributes, nodeToExpr(child))
+        (Explode(nodeToExpr(child)), attributes)
 
       case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
         val functionInfo: FunctionInfo =
@@ -1329,10 +1333,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           sys.error(s"Couldn't find function $functionName"))
         val functionClassName = functionInfo.getFunctionClass.getName
 
-        HiveGenericUdtf(
+        (HiveGenericUdtf(
           new HiveFunctionWrapper(functionClassName),
-          attributes,
-          children.map(nodeToExpr))
+          children.map(nodeToExpr)), attributes)
 
       case a: ASTNode =>
         throw new NotImplementedError(
@@ -66,7 +66,7 @@ private[hive] abstract class HiveFunctionRegistry
     } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
       HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children)
+      HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
     } else {
       sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
     }
@@ -266,7 +266,6 @@ private[hive] case class HiveUdaf(
  */
 private[hive] case class HiveGenericUdtf(
     funcWrapper: HiveFunctionWrapper,
-    aliasNames: Seq[String],
     children: Seq[Expression])
   extends Generator with HiveInspectors {
 
@@ -282,23 +281,8 @@ private[hive] case class HiveGenericUdtf(
   @transient
   protected lazy val udtInput = new Array[AnyRef](children.length)
 
-  protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map {
-    field => inspectorToDataType(field.getFieldObjectInspector)
-  }
-
-  override protected def makeOutput() = {
-    // Use column names when given, otherwise _c1, _c2, ... _cn.
-    if (aliasNames.size == outputDataTypes.size) {
-      aliasNames.zip(outputDataTypes).map {
-        case (attrName, attrDataType) =>
-          AttributeReference(attrName, attrDataType, nullable = true)()
-      }
-    } else {
-      outputDataTypes.zipWithIndex.map {
-        case (attrDataType, i) =>
-          AttributeReference(s"_c$i", attrDataType, nullable = true)()
-      }
-    }
-  }
+  lazy val elementTypes = outputInspector.getAllStructFieldRefs.map {
+    field => (inspectorToDataType(field.getFieldObjectInspector), true)
+  }
 
   override def eval(input: Row): TraversableOnce[Row] = {
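With `elementTypes` read off the output inspector, Hive UDTFs now flow through the same generic naming path as built-in generators; the `ResolveUdtfsAlias` rule removed below becomes redundant. A hedged example using the stock Hive UDTF `json_tuple` (assuming `sqlContext` over `src`; the values need not actually be JSON for the aliases to resolve):

```scala
// The AS (a, b) aliases are now resolved by the generic ResolveGenerate rule
// rather than by a Hive-specific one.
sqlContext.sql(
  """SELECT a, b FROM (
    |  SELECT json_tuple(value, 'f1', 'f2') AS (a, b) FROM src
    |) t""".stripMargin)
```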
@@ -333,22 +317,6 @@ private[hive] case class HiveGenericUdtf(
     }
   }
 
-/**
- * Resolve Udtfs Alias.
- */
-private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case p @ Project(projectList, _)
-      if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 =>
-      throw new TreeNodeException(p, "only single Generator supported for SELECT clause")
-
-    case Project(Seq(Alias(udtf @ HiveGenericUdtf(_, _, _), name)), child) =>
-      Generate(udtf.copy(aliasNames = Seq(name)), join = false, outer = false, None, child)
-    case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) =>
-      Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child)
-  }
-}
-
 private[hive] case class HiveUdafFunction(
     funcWrapper: HiveFunctionWrapper,
     exprs: Seq[Expression],
New golden answer files accompany the query tests added below:

@@ -0,0 +1 @@
+1
@@ -0,0 +1,3 @@
+1
+2
+3
@@ -0,0 +1,3 @@
+86	val_86
+238	val_238
+311	val_311
@@ -0,0 +1,3 @@
+1
+2
+3
@@ -27,7 +27,7 @@ import scala.util.Try
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 
 import org.apache.spark.{SparkFiles, SparkException}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive._
@@ -67,6 +67,40 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     }
   }
 
+  createQueryTest("insert table with generator with column name",
+    """
+      | CREATE TABLE gen_tmp (key Int);
+      | INSERT OVERWRITE TABLE gen_tmp
+      |   SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3;
+      | SELECT key FROM gen_tmp ORDER BY key ASC;
+    """.stripMargin)
+
+  createQueryTest("insert table with generator with multiple column names",
+    """
+      | CREATE TABLE gen_tmp (key Int, value String);
+      | INSERT OVERWRITE TABLE gen_tmp
+      |   SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3;
+      | SELECT key, value FROM gen_tmp ORDER BY key, value ASC;
+    """.stripMargin)
+
+  createQueryTest("insert table with generator without column name",
+    """
+      | CREATE TABLE gen_tmp (key Int);
+      | INSERT OVERWRITE TABLE gen_tmp
+      |   SELECT explode(array(1,2,3)) FROM src LIMIT 3;
+      | SELECT key FROM gen_tmp ORDER BY key ASC;
+    """.stripMargin)
+
+  test("multiple generator in projection") {
+    intercept[AnalysisException] {
+      sql("SELECT explode(map(key, value)), key FROM src").collect()
+    }
+
+    intercept[AnalysisException] {
+      sql("SELECT explode(map(key, value)) as k1, k2, key FROM src").collect()
+    }
+  }
+
   createQueryTest("! operator",
     """
       |SELECT a FROM (
@@ -456,7 +490,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
   createQueryTest("lateral view2",
     "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl")
 
-
   createQueryTest("lateral view3",
     "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX")
 
@@ -478,6 +511,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
   createQueryTest("lateral view6",
     "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")
 
+  createQueryTest("Specify the udtf output",
+    "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t")
+
   test("sampling") {
     sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
     sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s")