[SPARK-5817] [SQL] Fix bug of udtf with column names

It's a bug while do query like:
```sql
select d from (select explode(array(1,1)) d from src limit 1) t
```
And it will throws exception like:
```
org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249)
at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
at scala.collection.AbstractTraversable.map(Traversable.scala:105)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
```

To solve the bug, it requires code refactoring for UDTF
The major changes are about:
* Simplifying the UDTF development, UDTF will manage the output attribute names any more, instead, the `logical.Generate` will handle that properly.
* UDTF will be asked for the output schema (data types) during the logical plan analyzing.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #4602 from chenghao-intel/explode_bug and squashes the following commits:

c2a5132 [Cheng Hao] add back resolved for Alias
556e982 [Cheng Hao] revert the unncessary change
002c361 [Cheng Hao] change the rule of resolved for Generate
04ae500 [Cheng Hao] add qualifier only for generator output
5ee5d2c [Cheng Hao] prepend the new qualifier
d2e8b43 [Cheng Hao] Update the code as feedback
ca5e7f4 [Cheng Hao] shrink the commits
This commit is contained in:
Cheng Hao 2015-04-21 15:11:15 -07:00 committed by Michael Armbrust
parent 2a24bf92e6
commit 7662ec23bb
26 changed files with 207 additions and 145 deletions

View file

@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@ -59,6 +58,7 @@ class Analyzer(
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
GlobalAggregates ::
@ -474,8 +474,59 @@ class Analyzer(
*/
object ImplicitGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Project(Seq(Alias(g: Generator, _)), child) =>
Generate(g, join = false, outer = false, None, child)
case Project(Seq(Alias(g: Generator, name)), child) =>
Generate(g, join = false, outer = false,
qualifier = None, UnresolvedAttribute(name) :: Nil, child)
case Project(Seq(MultiAlias(g: Generator, names)), child) =>
Generate(g, join = false, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), child)
}
}
/**
* Resolve the Generate, if the output names specified, we will take them, otherwise
* we will try to provide the default names, which follow the same rule with Hive.
*/
object ResolveGenerate extends Rule[LogicalPlan] {
// Construct the output attributes for the generator,
// The output attribute names can be either specified or
// auto generated.
private def makeGeneratorOutput(
generator: Generator,
generatorOutput: Seq[Attribute]): Seq[Attribute] = {
val elementTypes = generator.elementTypes
if (generatorOutput.length == elementTypes.length) {
generatorOutput.zip(elementTypes).map {
case (a, (t, nullable)) if !a.resolved =>
AttributeReference(a.name, t, nullable)()
case (a, _) => a
}
} else if (generatorOutput.length == 0) {
elementTypes.zipWithIndex.map {
// keep the default column names as Hive does _c0, _c1, _cN
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
}
} else {
throw new AnalysisException(
s"""
|The number of aliases supplied in the AS clause does not match
|the number of columns output by the UDTF expected
|${elementTypes.size} aliases but got ${generatorOutput.size}
""".stripMargin)
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case p: Generate if p.resolved == false =>
// if the generator output names are not specified, we will use the default ones.
Generate(
p.generator,
join = p.join,
outer = p.outer,
p.qualifier,
makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
}
}
}

View file

@ -38,6 +38,12 @@ trait CheckAnalysis {
throw new AnalysisException(msg)
}
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
case e: Generator => true
}).length >= 1
}
def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
@ -110,6 +116,12 @@ trait CheckAnalysis {
failAnalysis(
s"unresolved operator ${operator.simpleString}")
case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
failAnalysis(
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
case _ => // Analysis successful!
}
}

View file

@ -284,12 +284,13 @@ package object dsl {
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)
// TODO specify the output column names
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None): LogicalPlan =
Generate(generator, join, outer, None, logicalPlan)
Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(

View file

@ -42,47 +42,30 @@ abstract class Generator extends Expression {
override type EvaluatedType = TraversableOnce[Row]
override lazy val dataType =
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
// TODO ideally we should return the type of ArrayType(StructType),
// however, we don't keep the output field names in the Generator.
override def dataType: DataType = throw new UnsupportedOperationException
override def nullable: Boolean = false
/**
* Should be overridden by specific generators. Called only once for each instance to ensure
* that rule application does not change the output schema of a generator.
* The output element data types in structure of Seq[(DataType, Nullable)]
* TODO we probably need to add more information like metadata etc.
*/
protected def makeOutput(): Seq[Attribute]
private var _output: Seq[Attribute] = null
def output: Seq[Attribute] = {
if (_output == null) {
_output = makeOutput()
}
_output
}
def elementTypes: Seq[(DataType, Boolean)]
/** Should be implemented by child classes to perform specific Generators. */
override def eval(input: Row): TraversableOnce[Row]
/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
val copy = super.makeCopy(newArgs)
copy._output = _output
copy
}
}
/**
* A generator that produces its output using the provided lambda function.
*/
case class UserDefinedGenerator(
schema: Seq[Attribute],
elementTypes: Seq[(DataType, Boolean)],
function: Row => TraversableOnce[Row],
children: Seq[Expression])
extends Generator{
override protected def makeOutput(): Seq[Attribute] = schema
extends Generator {
override def eval(input: Row): TraversableOnce[Row] = {
// TODO(davies): improve this
@ -98,30 +81,18 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
case class Explode(attributeNames: Seq[String], child: Expression)
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {
override lazy val resolved =
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
private lazy val elementTypes = child.dataType match {
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil
}
// TODO: Move this pattern into Generator.
protected def makeOutput() =
if (attributeNames.size == elementTypes.size) {
attributeNames.zip(elementTypes).map {
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
}
} else {
elementTypes.zipWithIndex.map {
case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()
}
}
override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
case ArrayType(_, _) =>

View file

@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {
override type EvaluatedType = Any
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
override def eval(input: Row): Any = child.eval(input)

View file

@ -482,16 +482,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition,
generate @ Generate(generator, join, outer, alias, grandChild)) =>
case filter @ Filter(condition, g: Generate) =>
// Predicates that reference attributes produced by the `Generate` operator cannot
// be pushed below the operator.
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
conjunct => conjunct.references subsetOf grandChild.outputSet
conjunct => conjunct.references subsetOf g.child.outputSet
}
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
} else {
filter

View file

@ -40,34 +40,43 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
* @param alias when set, this string is applied to the schema of the output of the transformation
* as a qualifier.
* @param qualifier Qualifier for the attributes of generator(UDTF)
* @param generatorOutput The output schema of the Generator.
* @param child Children logical plan node
*/
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
alias: Option[String],
qualifier: Option[String],
generatorOutput: Seq[Attribute],
child: LogicalPlan)
extends UnaryNode {
protected def generatorOutput: Seq[Attribute] = {
val output = alias
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
.getOrElse(generator.output)
if (join && outer) {
output.map(_.withNullability(true))
} else {
output
}
override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
generator.elementTypes.length == generatorOutput.length &&
!generatorOutput.exists(!_.resolved)
}
override def output: Seq[Attribute] =
if (join) child.output ++ generatorOutput else generatorOutput
// we don't want the gOutput to be taken as part of the expressions
// as that will cause exceptions like unresolved attributes etc.
override def expressions: Seq[Expression] = generator :: Nil
def output: Seq[Attribute] = {
val qualified = qualifier.map(q =>
// prepend the new qualifier to the existed one
generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers))
).getOrElse(generatorOutput)
if (join) child.output ++ qualified else qualified
}
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {

View file

@ -90,7 +90,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)
val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)

View file

@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
test("generate: predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('b >= 5) && ('a > 6))
}
val optimized = Optimize(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where(('b >= 5) && ('a > 6))
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
.generate(Explode('c_arr), true, false, Some("arr")).analyze
}
comparePlans(optimized, correctAnswer)
}
test("generate: part of conjuncts referenced generated column") {
val generator = Explode(Seq("c"), 'c_arr)
val generator = Explode('c_arr)
val originalQuery = {
testRelationWithArrayType
.generate(generator, true, false, Some("arr"))
@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
test("generate: all conjuncts referenced generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('c > 6) || ('b > 5)).analyze
}
val optimized = Optimize(originalQuery)

View file

@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
@ -711,12 +711,16 @@ class DataFrame private[sql](
*/
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributes = schema.toAttributes
val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
val names = schema.toAttributes.map(_.name)
val rowFunction =
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}
/**
@ -733,12 +737,17 @@ class DataFrame private[sql](
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
// TODO handle the metadata?
val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
val names = attributes.map(_.name)
def rowFunction(row: Row): TraversableOnce[Row] = {
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
}
val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}
/////////////////////////////////////////////////////////////////////////////

View file

@ -27,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
* @param output the output attributes of this node, which constructed in analysis phase,
* and we can not change it, as the parent node bound with it already.
*/
@DeveloperApi
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {
// This must be a val since the generator output expr ids are not preserved by serialization.
protected val generatorOutput: Seq[Attribute] = {
if (join && outer) {
generator.output.map(_.withNullability(true))
} else {
generator.output
}
}
// This must be a val since the generator output expr ids are not preserved by serialization.
override val output =
if (join) child.output ++ generatorOutput else generatorOutput
val boundGenerator = BindReferences.bindReference(generator, child.output)
override def execute(): RDD[Row] = {
if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null))
// Used to produce rows with no matches when outer = true.
val outerProjection =
newProjection(child.output ++ nullValues, child.output)
val joinProjection =
newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput)
val joinProjection = newProjection(output, output)
val joinedRow = new JoinedRow
iter.flatMap {row =>

View file

@ -312,8 +312,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Except(planLater(left), planLater(right)) :: Nil
case logical.Intersect(left, right) =>
execution.Intersect(planLater(left), planLater(right)) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
case g @ logical.Generate(generator, join, outer, _, _, child) =>
execution.Generate(
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>

View file

@ -249,7 +249,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.CreateTables ::
catalog.PreInsertionCasts ::
ExtractPythonUdfs ::
ResolveUdtfsAlias ::
sources.PreInsertCastAndRename ::
Nil
}

View file

@ -725,12 +725,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val alias =
getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
Generate(
nodesToGenerator(clauses),
join = true,
outer = false,
Some(alias.toLowerCase),
withWhere)
val (generator, attributes) = nodesToGenerator(clauses)
Generate(
generator,
join = true,
outer = false,
Some(alias.toLowerCase),
attributes.map(UnresolvedAttribute(_)),
withWhere)
}.getOrElse(withWhere)
// The projection of the query can either be a normal projection, an aggregation
@ -833,12 +835,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
Generate(
nodesToGenerator(clauses),
join = true,
outer = isOuter.nonEmpty,
Some(alias.toLowerCase),
nodeToRelation(relationClause))
val (generator, attributes) = nodesToGenerator(clauses)
Generate(
generator,
join = true,
outer = isOuter.nonEmpty,
Some(alias.toLowerCase),
attributes.map(UnresolvedAttribute(_)),
nodeToRelation(relationClause))
/* All relations, possibly with aliases or sampling clauses. */
case Token("TOK_TABREF", clauses) =>
@ -1311,7 +1315,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val explode = "(?i)explode".r
def nodesToGenerator(nodes: Seq[Node]): Generator = {
def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = {
val function = nodes.head
val attributes = nodes.flatMap {
@ -1321,7 +1325,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
function match {
case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
Explode(attributes, nodeToExpr(child))
(Explode(nodeToExpr(child)), attributes)
case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
val functionInfo: FunctionInfo =
@ -1329,10 +1333,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
sys.error(s"Couldn't find function $functionName"))
val functionClassName = functionInfo.getFunctionClass.getName
HiveGenericUdtf(
(HiveGenericUdtf(
new HiveFunctionWrapper(functionClassName),
attributes,
children.map(nodeToExpr))
children.map(nodeToExpr)), attributes)
case a: ASTNode =>
throw new NotImplementedError(

View file

@ -66,7 +66,7 @@ private[hive] abstract class HiveFunctionRegistry
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children)
HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
@ -266,7 +266,6 @@ private[hive] case class HiveUdaf(
*/
private[hive] case class HiveGenericUdtf(
funcWrapper: HiveFunctionWrapper,
aliasNames: Seq[String],
children: Seq[Expression])
extends Generator with HiveInspectors {
@ -282,23 +281,8 @@ private[hive] case class HiveGenericUdtf(
@transient
protected lazy val udtInput = new Array[AnyRef](children.length)
protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map {
field => inspectorToDataType(field.getFieldObjectInspector)
}
override protected def makeOutput() = {
// Use column names when given, otherwise _c1, _c2, ... _cn.
if (aliasNames.size == outputDataTypes.size) {
aliasNames.zip(outputDataTypes).map {
case (attrName, attrDataType) =>
AttributeReference(attrName, attrDataType, nullable = true)()
}
} else {
outputDataTypes.zipWithIndex.map {
case (attrDataType, i) =>
AttributeReference(s"_c$i", attrDataType, nullable = true)()
}
}
lazy val elementTypes = outputInspector.getAllStructFieldRefs.map {
field => (inspectorToDataType(field.getFieldObjectInspector), true)
}
override def eval(input: Row): TraversableOnce[Row] = {
@ -333,22 +317,6 @@ private[hive] case class HiveGenericUdtf(
}
}
/**
* Resolve Udtfs Alias.
*/
private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p @ Project(projectList, _)
if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 =>
throw new TreeNodeException(p, "only single Generator supported for SELECT clause")
case Project(Seq(Alias(udtf @ HiveGenericUdtf(_, _, _), name)), child) =>
Generate(udtf.copy(aliasNames = Seq(name)), join = false, outer = false, None, child)
case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) =>
Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child)
}
}
private[hive] case class HiveUdafFunction(
funcWrapper: HiveFunctionWrapper,
exprs: Seq[Expression],

View file

@ -27,7 +27,7 @@ import scala.util.Try
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.{SparkFiles, SparkException}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive._
@ -67,6 +67,40 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}
createQueryTest("insert table with generator with column name",
"""
| CREATE TABLE gen_tmp (key Int);
| INSERT OVERWRITE TABLE gen_tmp
| SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3;
| SELECT key FROM gen_tmp ORDER BY key ASC;
""".stripMargin)
createQueryTest("insert table with generator with multiple column names",
"""
| CREATE TABLE gen_tmp (key Int, value String);
| INSERT OVERWRITE TABLE gen_tmp
| SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3;
| SELECT key, value FROM gen_tmp ORDER BY key, value ASC;
""".stripMargin)
createQueryTest("insert table with generator without column name",
"""
| CREATE TABLE gen_tmp (key Int);
| INSERT OVERWRITE TABLE gen_tmp
| SELECT explode(array(1,2,3)) FROM src LIMIT 3;
| SELECT key FROM gen_tmp ORDER BY key ASC;
""".stripMargin)
test("multiple generator in projection") {
intercept[AnalysisException] {
sql("SELECT explode(map(key, value)), key FROM src").collect()
}
intercept[AnalysisException] {
sql("SELECT explode(map(key, value)) as k1, k2, key FROM src").collect()
}
}
createQueryTest("! operator",
"""
|SELECT a FROM (
@ -456,7 +490,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
createQueryTest("lateral view2",
"SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl")
createQueryTest("lateral view3",
"FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX")
@ -478,6 +511,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
createQueryTest("lateral view6",
"SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")
createQueryTest("Specify the udtf output",
"SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t")
test("sampling") {
sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s")