[SPARK-12593][SQL] Converts resolved logical plan back to SQL
This PR tries to enable Spark SQL to convert resolved logical plans back to SQL query strings. For now, the major use case is to canonicalize Spark SQL native view support. The major entry point is `SQLBuilder.toSQL`, which returns an `Option[String]` if the logical plan is recognized. The current version is still in WIP status, and is quite limited. Known limitations include: 1. The logical plan must be analyzed but not optimized The optimizer erases `Subquery` operators, which contain necessary scope information for SQL generation. Future versions should be able to recover erased scope information by inserting subqueries when necessary. 1. The logical plan must be created using HiveQL query string Query plans generated by composing arbitrary DataFrame API combinations are not supported yet. Operators within these query plans need to be rearranged into a canonical form that is more suitable for direct SQL generation. For example, the following query plan ``` Filter (a#1 < 10) +- MetastoreRelation default, src, None ``` need to be canonicalized into the following form before SQL generation: ``` Project [a#1, b#2, c#3] +- Filter (a#1 < 10) +- MetastoreRelation default, src, None ``` Otherwise, the SQL generation process will have to handle a large number of special cases. 1. Only a fraction of expressions and basic logical plan operators are supported in this PR Currently, 95.7% (1720 out of 1798) query plans in `HiveCompatibilitySuite` can be successfully converted to SQL query strings. Known unsupported components are: - Expressions - Part of math expressions - Part of string expressions (buggy?) - Null expressions - Calendar interval literal - Part of date time expressions - Complex type creators - Special `NOT` expressions, e.g. `NOT LIKE` and `NOT IN` - Logical plan operators/patterns - Cube, rollup, and grouping set - Script transformation - Generator - Distinct aggregation patterns that fit `DistinctAggregationRewriter` analysis rule - Window functions Support for window functions, generators, and cubes etc. will be added in follow-up PRs. This PR leverages `HiveCompatibilitySuite` for testing SQL generation in a "round-trip" manner: * For all select queries, we try to convert it back to SQL * If the query plan is convertible, we parse the generated SQL into a new logical plan * Run the new logical plan instead of the original one If the query plan is inconvertible, the test case simply falls back to the original logic. TODO - [x] Fix failed test cases - [x] Support for more basic expressions and logical plan operators (e.g. distinct aggregation etc.) - [x] Comments and documentation Author: Cheng Lian <lian@databricks.com> Closes #10541 from liancheng/sql-generation.
This commit is contained in:
parent
659fd9d04b
commit
d9447cac74
|
@ -639,7 +639,7 @@ import java.util.HashMap;
|
||||||
// counter to generate unique union aliases
|
// counter to generate unique union aliases
|
||||||
private int aliasCounter;
|
private int aliasCounter;
|
||||||
private String generateUnionAlias() {
|
private String generateUnionAlias() {
|
||||||
return "_u" + (++aliasCounter);
|
return "u_" + (++aliasCounter);
|
||||||
}
|
}
|
||||||
private char [] excludedCharForColumnName = {'.', ':'};
|
private char [] excludedCharForColumnName = {'.', ':'};
|
||||||
private boolean containExcludedCharForCreateTableColumnName(String input) {
|
private boolean containExcludedCharForCreateTableColumnName(String input) {
|
||||||
|
|
|
@ -86,8 +86,7 @@ class Analyzer(
|
||||||
HiveTypeCoercion.typeCoercionRules ++
|
HiveTypeCoercion.typeCoercionRules ++
|
||||||
extendedResolutionRules : _*),
|
extendedResolutionRules : _*),
|
||||||
Batch("Nondeterministic", Once,
|
Batch("Nondeterministic", Once,
|
||||||
PullOutNondeterministic,
|
PullOutNondeterministic),
|
||||||
ComputeCurrentTime),
|
|
||||||
Batch("UDF", Once,
|
Batch("UDF", Once,
|
||||||
HandleNullInputsForUDF),
|
HandleNullInputsForUDF),
|
||||||
Batch("Cleanup", fixedPoint,
|
Batch("Cleanup", fixedPoint,
|
||||||
|
@ -1229,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the current date and time to make sure we return the same result in a single query.
|
|
||||||
*/
|
|
||||||
object ComputeCurrentTime extends Rule[LogicalPlan] {
|
|
||||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
|
||||||
val dateExpr = CurrentDate()
|
|
||||||
val timeExpr = CurrentTimestamp()
|
|
||||||
val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
|
|
||||||
val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
|
|
||||||
|
|
||||||
plan transformAllExpressions {
|
|
||||||
case CurrentDate() => currentDate
|
|
||||||
case CurrentTimestamp() => currentTime
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
|
* Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -110,7 +110,9 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
|
||||||
|
|
||||||
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
|
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
|
||||||
// properly qualified with this alias.
|
// properly qualified with this alias.
|
||||||
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
|
alias
|
||||||
|
.map(a => Subquery(a, tableWithQualifiers))
|
||||||
|
.getOrElse(tableWithQualifiers)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
|
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
|
||||||
|
|
|
@ -931,6 +931,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
|
||||||
$evPrim = $result.copy();
|
$evPrim = $result.copy();
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = dataType match {
|
||||||
|
// HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this
|
||||||
|
// type of casting can only be introduced by the analyzer, and can be omitted when converting
|
||||||
|
// back to SQL query string.
|
||||||
|
case _: ArrayType | _: MapType | _: StructType => child.sql
|
||||||
|
case _ => s"CAST(${child.sql} AS ${dataType.sql})"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -18,9 +18,10 @@
|
||||||
package org.apache.spark.sql.catalyst.expressions
|
package org.apache.spark.sql.catalyst.expressions
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
|
import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute}
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||||
|
import org.apache.spark.sql.catalyst.util.sequenceOption
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -223,6 +224,15 @@ abstract class Expression extends TreeNode[Expression] {
|
||||||
protected def toCommentSafeString: String = this.toString
|
protected def toCommentSafeString: String = this.toString
|
||||||
.replace("*/", "\\*\\/")
|
.replace("*/", "\\*\\/")
|
||||||
.replace("\\u", "\\\\u")
|
.replace("\\u", "\\\\u")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns SQL representation of this expression. For expressions that don't have a SQL
|
||||||
|
* representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`.
|
||||||
|
*/
|
||||||
|
@throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation")
|
||||||
|
def sql: String = throw new UnsupportedOperationException(
|
||||||
|
s"Cannot map expression $this to its SQL representation"
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -356,6 +366,8 @@ abstract class UnaryExpression extends Expression {
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"($prettyName(${child.sql}))"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -456,6 +468,8 @@ abstract class BinaryExpression extends Expression {
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -492,6 +506,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
|
||||||
TypeCheckResult.TypeCheckSuccess
|
TypeCheckResult.TypeCheckSuccess
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"(${left.sql} $symbol ${right.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -593,4 +609,9 @@ abstract class TernaryExpression extends Expression {
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val childrenSQL = children.map(_.sql).mkString(", ")
|
||||||
|
s"$prettyName($childrenSQL)"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
|
||||||
"org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
|
"org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = prettyName
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with
|
||||||
$countTerm++;
|
$countTerm++;
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "monotonically_increasing_id"
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName()"
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,9 +24,17 @@ import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
|
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
|
||||||
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
|
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
|
||||||
|
|
||||||
abstract sealed class SortDirection
|
abstract sealed class SortDirection {
|
||||||
case object Ascending extends SortDirection
|
def sql: String
|
||||||
case object Descending extends SortDirection
|
}
|
||||||
|
|
||||||
|
case object Ascending extends SortDirection {
|
||||||
|
override def sql: String = "ASC"
|
||||||
|
}
|
||||||
|
|
||||||
|
case object Descending extends SortDirection {
|
||||||
|
override def sql: String = "DESC"
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An expression that can be used to sort a tuple. This class extends expression primarily so that
|
* An expression that can be used to sort a tuple. This class extends expression primarily so that
|
||||||
|
|
|
@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode}
|
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
||||||
|
import org.apache.spark.sql.catalyst.util.sequenceOption
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
/** The mode of an [[AggregateFunction]]. */
|
/** The mode of an [[AggregateFunction]]. */
|
||||||
|
@ -93,11 +94,13 @@ private[sql] case class AggregateExpression(
|
||||||
|
|
||||||
override def prettyString: String = aggregateFunction.prettyString
|
override def prettyString: String = aggregateFunction.prettyString
|
||||||
|
|
||||||
override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
|
override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)"
|
||||||
|
|
||||||
|
override def sql: String = aggregateFunction.sql(isDistinct)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AggregateFunction2 is the superclass of two aggregation function interfaces:
|
* AggregateFunction is the superclass of two aggregation function interfaces:
|
||||||
*
|
*
|
||||||
* - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of
|
* - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of
|
||||||
* initialize(), update(), and merge() functions that operate on Row-based aggregation buffers.
|
* initialize(), update(), and merge() functions that operate on Row-based aggregation buffers.
|
||||||
|
@ -163,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu
|
||||||
def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
|
def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
|
||||||
AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
|
AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def sql(isDistinct: Boolean): String = {
|
||||||
|
val distinct = if (isDistinct) "DISTINCT " else " "
|
||||||
|
s"$prettyName($distinct${children.map(_.sql).mkString(", ")})"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
|
||||||
numeric.negate(input)
|
numeric.negate(input)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"(-${child.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
|
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
|
||||||
|
@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
|
||||||
defineCodeGen(ctx, ev, c => c)
|
defineCodeGen(ctx, ev, c => c)
|
||||||
|
|
||||||
protected override def nullSafeEval(input: Any): Any = input
|
protected override def nullSafeEval(input: Any): Any = input
|
||||||
|
|
||||||
|
override def sql: String = s"(+${child.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
|
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName(${child.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract class BinaryArithmetic extends BinaryOperator {
|
abstract class BinaryArithmetic extends BinaryOperator {
|
||||||
|
@ -513,4 +519,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||||
val r = a % n
|
val r = a % n
|
||||||
if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
|
if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,6 +130,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = child.sql + s".`${childSchema(ordinal).name}`"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils}
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,6 +74,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
|
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
|
||||||
|
|
||||||
|
override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
|
||||||
}
|
}
|
||||||
|
|
||||||
trait CaseWhenLike extends Expression {
|
trait CaseWhenLike extends Expression {
|
||||||
|
@ -110,7 +112,7 @@ trait CaseWhenLike extends Expression {
|
||||||
|
|
||||||
override def nullable: Boolean = {
|
override def nullable: Boolean = {
|
||||||
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
|
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
|
||||||
thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
|
thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,6 +208,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
|
||||||
case Seq(elseValue) => s" ELSE $elseValue"
|
case Seq(elseValue) => s" ELSE $elseValue"
|
||||||
}.mkString
|
}.mkString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val branchesSQL = branches.map(_.sql)
|
||||||
|
val (cases, maybeElse) = if (branches.length % 2 == 0) {
|
||||||
|
(branchesSQL, None)
|
||||||
|
} else {
|
||||||
|
(branchesSQL.init, Some(branchesSQL.last))
|
||||||
|
}
|
||||||
|
|
||||||
|
val head = s"CASE "
|
||||||
|
val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
|
||||||
|
val body = cases.grouped(2).map {
|
||||||
|
case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
|
||||||
|
}.mkString(" ")
|
||||||
|
|
||||||
|
head + body + tail
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// scalastyle:off
|
// scalastyle:off
|
||||||
|
@ -310,6 +329,24 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
|
||||||
case Seq(elseValue) => s" ELSE $elseValue"
|
case Seq(elseValue) => s" ELSE $elseValue"
|
||||||
}.mkString
|
}.mkString
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val keySQL = key.sql
|
||||||
|
val branchesSQL = branches.map(_.sql)
|
||||||
|
val (cases, maybeElse) = if (branches.length % 2 == 0) {
|
||||||
|
(branchesSQL, None)
|
||||||
|
} else {
|
||||||
|
(branchesSQL.init, Some(branchesSQL.last))
|
||||||
|
}
|
||||||
|
|
||||||
|
val head = s"CASE $keySQL "
|
||||||
|
val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
|
||||||
|
val body = cases.grouped(2).map {
|
||||||
|
case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
|
||||||
|
}.mkString(" ")
|
||||||
|
|
||||||
|
head + body + tail
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback {
|
||||||
override def eval(input: InternalRow): Any = {
|
override def eval(input: InternalRow): Any = {
|
||||||
DateTimeUtils.millisToDays(System.currentTimeMillis())
|
DateTimeUtils.millisToDays(System.currentTimeMillis())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "current_date"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
|
||||||
override def eval(input: InternalRow): Any = {
|
override def eval(input: InternalRow): Any = {
|
||||||
System.currentTimeMillis() * 1000L
|
System.currentTimeMillis() * 1000L
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "current_timestamp"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression)
|
||||||
s"""${ev.value} = $sd + $d;"""
|
s"""${ev.value} = $sd + $d;"""
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "date_add"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression)
|
||||||
s"""${ev.value} = $sd - $d;"""
|
s"""${ev.value} = $sd - $d;"""
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "date_sub"
|
||||||
}
|
}
|
||||||
|
|
||||||
case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
|
case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
|
||||||
|
@ -309,6 +317,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix
|
||||||
def this(time: Expression) = {
|
def this(time: Expression) = {
|
||||||
this(time, Literal("yyyy-MM-dd HH:mm:ss"))
|
this(time, Literal("yyyy-MM-dd HH:mm:ss"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "to_unix_timestamp"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -332,6 +342,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi
|
||||||
def this() = {
|
def this() = {
|
||||||
this(CurrentTimestamp())
|
this(CurrentTimestamp())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "unix_timestamp"
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
|
abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
|
||||||
|
@ -437,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "unix_time"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -451,6 +465,8 @@ case class FromUnixTime(sec: Expression, format: Expression)
|
||||||
override def left: Expression = sec
|
override def left: Expression = sec
|
||||||
override def right: Expression = format
|
override def right: Expression = format
|
||||||
|
|
||||||
|
override def prettyName: String = "from_unixtime"
|
||||||
|
|
||||||
def this(unix: Expression) = {
|
def this(unix: Expression) = {
|
||||||
this(unix, Literal("yyyy-MM-dd HH:mm:ss"))
|
this(unix, Literal("yyyy-MM-dd HH:mm:ss"))
|
||||||
}
|
}
|
||||||
|
@ -733,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
|
||||||
s"""$dtu.dateAddMonths($sd, $m)"""
|
s"""$dtu.dateAddMonths($sd, $m)"""
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "add_months"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -758,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression)
|
||||||
s"""$dtu.monthsBetween($l, $r)"""
|
s"""$dtu.monthsBetween($l, $r)"""
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "months_between"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -823,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn
|
||||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||||
defineCodeGen(ctx, ev, d => d)
|
defineCodeGen(ctx, ev, d => d)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "to_date"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -73,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
|
||||||
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
|
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
|
||||||
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
|
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
|
||||||
override def prettyName: String = "promote_precision"
|
override def prettyName: String = "promote_precision"
|
||||||
|
override def sql: String = child.sql
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -107,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = s"CheckOverflow($child, $dataType)"
|
override def toString: String = s"CheckOverflow($child, $dataType)"
|
||||||
|
|
||||||
|
override def sql: String = child.sql
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,9 +21,9 @@ import java.sql.{Date, Timestamp}
|
||||||
|
|
||||||
import org.json4s.JsonAST._
|
import org.json4s.JsonAST._
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||||
|
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.unsafe.types._
|
import org.apache.spark.unsafe.types._
|
||||||
|
|
||||||
|
@ -214,6 +214,41 @@ case class Literal protected (value: Any, dataType: DataType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = (value, dataType) match {
|
||||||
|
case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null =>
|
||||||
|
"NULL"
|
||||||
|
|
||||||
|
case _ if value == null =>
|
||||||
|
s"CAST(NULL AS ${dataType.sql})"
|
||||||
|
|
||||||
|
case (v: UTF8String, StringType) =>
|
||||||
|
// Escapes all backslashes and double quotes.
|
||||||
|
"\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
|
||||||
|
|
||||||
|
case (v: Byte, ByteType) =>
|
||||||
|
s"CAST($v AS ${ByteType.simpleString.toUpperCase})"
|
||||||
|
|
||||||
|
case (v: Short, ShortType) =>
|
||||||
|
s"CAST($v AS ${ShortType.simpleString.toUpperCase})"
|
||||||
|
|
||||||
|
case (v: Long, LongType) =>
|
||||||
|
s"CAST($v AS ${LongType.simpleString.toUpperCase})"
|
||||||
|
|
||||||
|
case (v: Float, FloatType) =>
|
||||||
|
s"CAST($v AS ${FloatType.simpleString.toUpperCase})"
|
||||||
|
|
||||||
|
case (v: Decimal, DecimalType.Fixed(precision, scale)) =>
|
||||||
|
s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))"
|
||||||
|
|
||||||
|
case (v: Int, DateType) =>
|
||||||
|
s"DATE '${DateTimeUtils.toJavaDate(v)}'"
|
||||||
|
|
||||||
|
case (v: Long, TimestampType) =>
|
||||||
|
s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
|
||||||
|
|
||||||
|
case _ => value.toString
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Specialize
|
// TODO: Specialize
|
||||||
|
|
|
@ -70,6 +70,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
|
||||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||||
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
|
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$name(${child.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract class UnaryLogExpression(f: Double => Double, name: String)
|
abstract class UnaryLogExpression(f: Double => Double, name: String)
|
||||||
|
|
|
@ -220,4 +220,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
|
||||||
final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
|
final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def prettyName: String = "hash"
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
|
||||||
}
|
}
|
||||||
|
|
|
@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)(
|
||||||
explicitMetadata == a.explicitMetadata
|
explicitMetadata == a.explicitMetadata
|
||||||
case _ => false
|
case _ => false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val qualifiersString =
|
||||||
|
if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
|
||||||
|
s"${child.sql} AS $qualifiersString`$name`"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -271,6 +277,12 @@ case class AttributeReference(
|
||||||
// Since the expression id is not in the first constructor it is missing from the default
|
// Since the expression id is not in the first constructor it is missing from the default
|
||||||
// tree string.
|
// tree string.
|
||||||
override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
|
override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val qualifiersString =
|
||||||
|
if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
|
||||||
|
s"$qualifiersString`$name`"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
||||||
"""
|
"""
|
||||||
}.mkString("\n")
|
}.mkString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -193,6 +195,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
|
||||||
ev.value = eval.isNull
|
ev.value = eval.isNull
|
||||||
eval.code
|
eval.code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"(${child.sql} IS NULL)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -212,6 +216,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
|
||||||
ev.value = s"(!(${eval.isNull}))"
|
ev.value = s"(!(${eval.isNull}))"
|
||||||
eval.code
|
eval.code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"(${child.sql} IS NOT NULL)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -101,6 +101,8 @@ case class Not(child: Expression)
|
||||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||||
defineCodeGen(ctx, ev, c => s"!($c)")
|
defineCodeGen(ctx, ev, c => s"!($c)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"(NOT ${child.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -176,6 +178,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val childrenSQL = children.map(_.sql)
|
||||||
|
val valueSQL = childrenSQL.head
|
||||||
|
val listSQL = childrenSQL.tail.mkString(", ")
|
||||||
|
s"($valueSQL IN ($listSQL))"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -226,6 +235,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val valueSQL = child.sql
|
||||||
|
val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ")
|
||||||
|
s"($valueSQL IN ($listSQL))"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
|
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
|
||||||
|
@ -274,6 +289,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"(${left.sql} AND ${right.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -323,6 +340,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"(${left.sql} OR ${right.sql})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,9 @@ abstract class RDG extends LeafExpression with Nondeterministic {
|
||||||
override def nullable: Boolean = false
|
override def nullable: Boolean = false
|
||||||
|
|
||||||
override def dataType: DataType = DoubleType
|
override def dataType: DataType = DoubleType
|
||||||
|
|
||||||
|
// NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
|
||||||
|
override def sql: String = s"$prettyName($seed)"
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
|
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
|
||||||
|
|
|
@ -59,6 +59,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
|
||||||
matches(regex, input1.asInstanceOf[UTF8String].toString)
|
matches(regex, input1.asInstanceOf[UTF8String].toString)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap}
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.util.ArrayData
|
import org.apache.spark.sql.catalyst.util.ArrayData
|
||||||
|
import org.apache.spark.sql.catalyst.util.sequenceOption
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||||
|
|
||||||
|
@ -61,6 +62,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,6 +156,8 @@ case class ConcatWs(children: Seq[Expression])
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
trait String2StringExpression extends ImplicitCastInputTypes {
|
trait String2StringExpression extends ImplicitCastInputTypes {
|
||||||
|
@ -292,24 +297,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
|
||||||
val termDict = ctx.freshName("dict")
|
val termDict = ctx.freshName("dict")
|
||||||
val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
|
val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
|
||||||
|
|
||||||
ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;")
|
ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;")
|
||||||
ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;")
|
ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;")
|
||||||
ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;")
|
ctx.addMutableState(classNameDict, termDict, s"$termDict = null;")
|
||||||
|
|
||||||
nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
|
nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
|
||||||
val check = if (matchingExpr.foldable && replaceExpr.foldable) {
|
val check = if (matchingExpr.foldable && replaceExpr.foldable) {
|
||||||
s"${termDict} == null"
|
s"$termDict == null"
|
||||||
} else {
|
} else {
|
||||||
s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})"
|
s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)"
|
||||||
}
|
}
|
||||||
s"""if ($check) {
|
s"""if ($check) {
|
||||||
// Not all of them is literal or matching or replace value changed
|
// Not all of them is literal or matching or replace value changed
|
||||||
${termLastMatching} = ${matching}.clone();
|
$termLastMatching = $matching.clone();
|
||||||
${termLastReplace} = ${replace}.clone();
|
$termLastReplace = $replace.clone();
|
||||||
${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate
|
$termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate
|
||||||
.buildDict(${termLastMatching}, ${termLastReplace});
|
.buildDict($termLastMatching, $termLastReplace);
|
||||||
}
|
}
|
||||||
${ev.value} = ${src}.translate(${termDict});
|
${ev.value} = $src.translate($termDict);
|
||||||
"""
|
"""
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -340,6 +345,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
|
||||||
}
|
}
|
||||||
|
|
||||||
override def dataType: DataType = IntegerType
|
override def dataType: DataType = IntegerType
|
||||||
|
|
||||||
|
override def prettyName: String = "find_in_set"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -832,7 +839,6 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
|
||||||
org.apache.commons.codec.binary.Base64.encodeBase64($child));
|
org.apache.commons.codec.binary.Base64.encodeBase64($child));
|
||||||
"""})
|
"""})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -37,6 +37,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
|
||||||
// SubQueries are only needed for analysis and can be removed before execution.
|
// SubQueries are only needed for analysis and can be removed before execution.
|
||||||
Batch("Remove SubQueries", FixedPoint(100),
|
Batch("Remove SubQueries", FixedPoint(100),
|
||||||
EliminateSubQueries) ::
|
EliminateSubQueries) ::
|
||||||
|
Batch("Compute Current Time", Once,
|
||||||
|
ComputeCurrentTime) ::
|
||||||
Batch("Aggregate", FixedPoint(100),
|
Batch("Aggregate", FixedPoint(100),
|
||||||
ReplaceDistinctWithAggregate,
|
ReplaceDistinctWithAggregate,
|
||||||
RemoveLiteralFromGroupExpressions) ::
|
RemoveLiteralFromGroupExpressions) ::
|
||||||
|
@ -333,6 +335,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
|
||||||
)
|
)
|
||||||
Project(cleanedProjection, child)
|
Project(cleanedProjection, child)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO Eliminate duplicate code
|
||||||
|
// This clause is identical to the one above except that the inner operator is an `Aggregate`
|
||||||
|
// rather than a `Project`.
|
||||||
|
case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
|
||||||
|
// Create a map of Aliases to their values from the child projection.
|
||||||
|
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
|
||||||
|
val aliasMap = AttributeMap(projectList2.collect {
|
||||||
|
case a: Alias => (a.toAttribute, a)
|
||||||
|
})
|
||||||
|
|
||||||
|
// We only collapse these two Projects if their overlapped expressions are all
|
||||||
|
// deterministic.
|
||||||
|
val hasNondeterministic = projectList1.exists(_.collect {
|
||||||
|
case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
|
||||||
|
}.exists(!_.deterministic))
|
||||||
|
|
||||||
|
if (hasNondeterministic) {
|
||||||
|
p
|
||||||
|
} else {
|
||||||
|
// Substitute any attributes that are produced by the child projection, so that we safely
|
||||||
|
// eliminate it.
|
||||||
|
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
|
||||||
|
// TODO: Fix TransformBase to avoid the cast below.
|
||||||
|
val substitutedProjection = projectList1.map(_.transform {
|
||||||
|
case a: Attribute => aliasMap.getOrElse(a, a)
|
||||||
|
}).asInstanceOf[Seq[NamedExpression]]
|
||||||
|
// collapse 2 projects may introduce unnecessary Aliases, trim them here.
|
||||||
|
val cleanedProjection = substitutedProjection.map(p =>
|
||||||
|
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
|
||||||
|
)
|
||||||
|
agg.copy(aggregateExpressions = cleanedProjection)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -976,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
|
||||||
a.copy(groupingExpressions = newGrouping)
|
a.copy(groupingExpressions = newGrouping)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the current date and time to make sure we return the same result in a single query.
|
||||||
|
*/
|
||||||
|
object ComputeCurrentTime extends Rule[LogicalPlan] {
|
||||||
|
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||||
|
val dateExpr = CurrentDate()
|
||||||
|
val timeExpr = CurrentTimestamp()
|
||||||
|
val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
|
||||||
|
val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
|
||||||
|
|
||||||
|
plan transformAllExpressions {
|
||||||
|
case CurrentDate() => currentDate
|
||||||
|
case CurrentTimestamp() => currentTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -37,14 +37,26 @@ object JoinType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sealed abstract class JoinType
|
sealed abstract class JoinType {
|
||||||
|
def sql: String
|
||||||
|
}
|
||||||
|
|
||||||
case object Inner extends JoinType
|
case object Inner extends JoinType {
|
||||||
|
override def sql: String = "INNER"
|
||||||
|
}
|
||||||
|
|
||||||
case object LeftOuter extends JoinType
|
case object LeftOuter extends JoinType {
|
||||||
|
override def sql: String = "LEFT OUTER"
|
||||||
|
}
|
||||||
|
|
||||||
case object RightOuter extends JoinType
|
case object RightOuter extends JoinType {
|
||||||
|
override def sql: String = "RIGHT OUTER"
|
||||||
|
}
|
||||||
|
|
||||||
case object FullOuter extends JoinType
|
case object FullOuter extends JoinType {
|
||||||
|
override def sql: String = "FULL OUTER"
|
||||||
|
}
|
||||||
|
|
||||||
case object LeftSemi extends JoinType
|
case object LeftSemi extends JoinType {
|
||||||
|
override def sql: String = "LEFT SEMI"
|
||||||
|
}
|
||||||
|
|
|
@ -423,6 +423,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
|
||||||
}
|
}
|
||||||
|
|
||||||
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
|
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
|
||||||
|
|
||||||
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
|
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ object RuleExecutor {
|
||||||
val maxSize = map.keys.map(_.toString.length).max
|
val maxSize = map.keys.map(_.toString.length).max
|
||||||
map.toSeq.sortBy(_._2).reverseMap { case (k, v) =>
|
map.toSeq.sortBy(_._2).reverseMap { case (k, v) =>
|
||||||
s"${k.padTo(maxSize, " ").mkString} $v"
|
s"${k.padTo(maxSize, " ").mkString} $v"
|
||||||
}.mkString("\n")
|
}.mkString("\n", "\n", "")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -130,6 +130,20 @@ package object util {
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`.
|
||||||
|
*/
|
||||||
|
def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
|
||||||
|
case xs if xs.isEmpty =>
|
||||||
|
Option(Seq.empty[T])
|
||||||
|
|
||||||
|
case xs =>
|
||||||
|
for {
|
||||||
|
head <- xs.head
|
||||||
|
tail <- sequenceOption(xs.tail)
|
||||||
|
} yield head +: tail
|
||||||
|
}
|
||||||
|
|
||||||
/* FIX ME
|
/* FIX ME
|
||||||
implicit class debugLogging(a: Any) {
|
implicit class debugLogging(a: Any) {
|
||||||
def debugLogging() {
|
def debugLogging() {
|
||||||
|
|
|
@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
|
||||||
|
|
||||||
override def simpleString: String = s"array<${elementType.simpleString}>"
|
override def simpleString: String = s"array<${elementType.simpleString}>"
|
||||||
|
|
||||||
|
override def sql: String = s"ARRAY<${elementType.sql}>"
|
||||||
|
|
||||||
override private[spark] def asNullable: ArrayType =
|
override private[spark] def asNullable: ArrayType =
|
||||||
ArrayType(elementType.asNullable, containsNull = true)
|
ArrayType(elementType.asNullable, containsNull = true)
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,8 @@ abstract class DataType extends AbstractDataType {
|
||||||
/** Readable string representation for the type with truncation */
|
/** Readable string representation for the type with truncation */
|
||||||
private[sql] def simpleString(maxNumberFields: Int): String = simpleString
|
private[sql] def simpleString(maxNumberFields: Int): String = simpleString
|
||||||
|
|
||||||
|
def sql: String = simpleString.toUpperCase
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if `this` and `other` are the same data type when ignoring nullability
|
* Check if `this` and `other` are the same data type when ignoring nullability
|
||||||
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
|
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
|
||||||
|
|
|
@ -62,6 +62,8 @@ case class MapType(
|
||||||
|
|
||||||
override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
|
override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
|
||||||
|
|
||||||
|
override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>"
|
||||||
|
|
||||||
override private[spark] def asNullable: MapType =
|
override private[spark] def asNullable: MapType =
|
||||||
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
|
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
|
||||||
|
|
||||||
|
|
|
@ -279,6 +279,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
|
||||||
s"struct<${fieldTypes.mkString(",")}>"
|
s"struct<${fieldTypes.mkString(",")}>"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = {
|
||||||
|
val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}")
|
||||||
|
s"STRUCT<${fieldTypes.mkString(", ")}>"
|
||||||
|
}
|
||||||
|
|
||||||
private[sql] override def simpleString(maxNumberFields: Int): String = {
|
private[sql] override def simpleString(maxNumberFields: Int): String = {
|
||||||
val builder = new StringBuilder
|
val builder = new StringBuilder
|
||||||
val fieldTypes = fields.take(maxNumberFields).map {
|
val fieldTypes = fields.take(maxNumberFields).map {
|
||||||
|
|
|
@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
|
||||||
|
|
||||||
override private[sql] def acceptsType(dataType: DataType) =
|
override private[sql] def acceptsType(dataType: DataType) =
|
||||||
this.getClass == dataType.getClass
|
this.getClass == dataType.getClass
|
||||||
|
|
||||||
|
override def sql: String = sqlType.sql
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.plans.logical._
|
import org.apache.spark.sql.catalyst.plans.logical._
|
||||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
class AnalysisSuite extends AnalysisTest {
|
class AnalysisSuite extends AnalysisTest {
|
||||||
|
@ -238,43 +237,6 @@ class AnalysisSuite extends AnalysisTest {
|
||||||
checkAnalysis(plan, expected)
|
checkAnalysis(plan, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("analyzer should replace current_timestamp with literals") {
|
|
||||||
val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
|
|
||||||
LocalRelation())
|
|
||||||
|
|
||||||
val min = System.currentTimeMillis() * 1000
|
|
||||||
val plan = in.analyze.asInstanceOf[Project]
|
|
||||||
val max = (System.currentTimeMillis() + 1) * 1000
|
|
||||||
|
|
||||||
val lits = new scala.collection.mutable.ArrayBuffer[Long]
|
|
||||||
plan.transformAllExpressions { case e: Literal =>
|
|
||||||
lits += e.value.asInstanceOf[Long]
|
|
||||||
e
|
|
||||||
}
|
|
||||||
assert(lits.size == 2)
|
|
||||||
assert(lits(0) >= min && lits(0) <= max)
|
|
||||||
assert(lits(1) >= min && lits(1) <= max)
|
|
||||||
assert(lits(0) == lits(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
test("analyzer should replace current_date with literals") {
|
|
||||||
val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
|
|
||||||
|
|
||||||
val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
|
|
||||||
val plan = in.analyze.asInstanceOf[Project]
|
|
||||||
val max = DateTimeUtils.millisToDays(System.currentTimeMillis())
|
|
||||||
|
|
||||||
val lits = new scala.collection.mutable.ArrayBuffer[Int]
|
|
||||||
plan.transformAllExpressions { case e: Literal =>
|
|
||||||
lits += e.value.asInstanceOf[Int]
|
|
||||||
e
|
|
||||||
}
|
|
||||||
assert(lits.size == 2)
|
|
||||||
assert(lits(0) >= min && lits(0) <= max)
|
|
||||||
assert(lits(1) >= min && lits(1) <= max)
|
|
||||||
assert(lits(0) == lits(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
|
test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
|
||||||
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
|
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
|
||||||
val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
|
val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.catalyst.optimizer
|
||||||
|
|
||||||
|
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
|
||||||
|
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||||
|
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
|
||||||
|
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||||
|
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||||
|
|
||||||
|
class ComputeCurrentTimeSuite extends PlanTest {
|
||||||
|
object Optimize extends RuleExecutor[LogicalPlan] {
|
||||||
|
val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("analyzer should replace current_timestamp with literals") {
|
||||||
|
val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
|
||||||
|
LocalRelation())
|
||||||
|
|
||||||
|
val min = System.currentTimeMillis() * 1000
|
||||||
|
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
|
||||||
|
val max = (System.currentTimeMillis() + 1) * 1000
|
||||||
|
|
||||||
|
val lits = new scala.collection.mutable.ArrayBuffer[Long]
|
||||||
|
plan.transformAllExpressions { case e: Literal =>
|
||||||
|
lits += e.value.asInstanceOf[Long]
|
||||||
|
e
|
||||||
|
}
|
||||||
|
assert(lits.size == 2)
|
||||||
|
assert(lits(0) >= min && lits(0) <= max)
|
||||||
|
assert(lits(1) >= min && lits(1) <= max)
|
||||||
|
assert(lits(0) == lits(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("analyzer should replace current_date with literals") {
|
||||||
|
val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
|
||||||
|
|
||||||
|
val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
|
||||||
|
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
|
||||||
|
val max = DateTimeUtils.millisToDays(System.currentTimeMillis())
|
||||||
|
|
||||||
|
val lits = new scala.collection.mutable.ArrayBuffer[Int]
|
||||||
|
plan.transformAllExpressions { case e: Literal =>
|
||||||
|
lits += e.value.asInstanceOf[Int]
|
||||||
|
e
|
||||||
|
}
|
||||||
|
assert(lits.size == 2)
|
||||||
|
assert(lits(0) >= min && lits(0) <= max)
|
||||||
|
assert(lits(1) >= min && lits(1) <= max)
|
||||||
|
assert(lits(0) == lits(1))
|
||||||
|
}
|
||||||
|
}
|
|
@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest {
|
||||||
val correctAnswer =
|
val correctAnswer =
|
||||||
testRelation
|
testRelation
|
||||||
.select('a)
|
.select('a)
|
||||||
.groupBy('a)('a)
|
.groupBy('a)('a).analyze
|
||||||
.select('a).analyze
|
|
||||||
|
|
||||||
comparePlans(optimized, correctAnswer)
|
comparePlans(optimized, correctAnswer)
|
||||||
}
|
}
|
||||||
|
@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest {
|
||||||
val correctAnswer =
|
val correctAnswer =
|
||||||
testRelation
|
testRelation
|
||||||
.select('a)
|
.select('a)
|
||||||
.groupBy('a)('a as 'c)
|
.groupBy('a)('a as 'c).analyze
|
||||||
.select('c).analyze
|
|
||||||
|
|
||||||
comparePlans(optimized, correctAnswer)
|
comparePlans(optimized, correctAnswer)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,8 +18,8 @@
|
||||||
package org.apache.spark.sql.execution.datasources.parquet
|
package org.apache.spark.sql.execution.datasources.parquet
|
||||||
|
|
||||||
import java.net.URI
|
import java.net.URI
|
||||||
import java.util.{List => JList}
|
|
||||||
import java.util.logging.{Logger => JLogger}
|
import java.util.logging.{Logger => JLogger}
|
||||||
|
import java.util.{List => JList}
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
@ -32,24 +32,24 @@ import org.apache.hadoop.io.Writable
|
||||||
import org.apache.hadoop.mapreduce._
|
import org.apache.hadoop.mapreduce._
|
||||||
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
|
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
|
||||||
import org.apache.hadoop.mapreduce.task.JobContextImpl
|
import org.apache.hadoop.mapreduce.task.JobContextImpl
|
||||||
import org.apache.parquet.{Log => ApacheParquetLog}
|
|
||||||
import org.apache.parquet.filter2.predicate.FilterApi
|
import org.apache.parquet.filter2.predicate.FilterApi
|
||||||
import org.apache.parquet.hadoop._
|
import org.apache.parquet.hadoop._
|
||||||
import org.apache.parquet.hadoop.metadata.CompressionCodecName
|
import org.apache.parquet.hadoop.metadata.CompressionCodecName
|
||||||
import org.apache.parquet.hadoop.util.ContextUtil
|
import org.apache.parquet.hadoop.util.ContextUtil
|
||||||
import org.apache.parquet.schema.MessageType
|
import org.apache.parquet.schema.MessageType
|
||||||
|
import org.apache.parquet.{Log => ApacheParquetLog}
|
||||||
import org.slf4j.bridge.SLF4JBridgeHandler
|
import org.slf4j.bridge.SLF4JBridgeHandler
|
||||||
|
|
||||||
import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
|
|
||||||
import org.apache.spark.broadcast.Broadcast
|
import org.apache.spark.broadcast.Broadcast
|
||||||
import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD}
|
import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD}
|
||||||
import org.apache.spark.sql._
|
import org.apache.spark.sql._
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
|
||||||
import org.apache.spark.sql.execution.datasources._
|
|
||||||
import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
|
import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
|
||||||
|
import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier}
|
||||||
|
import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
|
||||||
import org.apache.spark.sql.sources._
|
import org.apache.spark.sql.sources._
|
||||||
import org.apache.spark.sql.types.{DataType, StructType}
|
import org.apache.spark.sql.types.{DataType, StructType}
|
||||||
import org.apache.spark.util.{SerializableConfiguration, Utils}
|
import org.apache.spark.util.{SerializableConfiguration, Utils}
|
||||||
|
import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
|
||||||
|
|
||||||
private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
|
private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
|
||||||
|
|
||||||
|
@ -147,6 +147,12 @@ private[sql] class ParquetRelation(
|
||||||
.get(ParquetRelation.METASTORE_SCHEMA)
|
.get(ParquetRelation.METASTORE_SCHEMA)
|
||||||
.map(DataType.fromJson(_).asInstanceOf[StructType])
|
.map(DataType.fromJson(_).asInstanceOf[StructType])
|
||||||
|
|
||||||
|
// If this relation is converted from a Hive metastore table, this method returns the name of the
|
||||||
|
// original Hive metastore table.
|
||||||
|
private[sql] def metastoreTableName: Option[TableIdentifier] = {
|
||||||
|
parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier)
|
||||||
|
}
|
||||||
|
|
||||||
private lazy val metadataCache: MetadataCache = {
|
private lazy val metadataCache: MetadataCache = {
|
||||||
val meta = new MetadataCache
|
val meta = new MetadataCache
|
||||||
meta.refresh()
|
meta.refresh()
|
||||||
|
|
|
@ -41,9 +41,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
||||||
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
|
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
|
||||||
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
|
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
|
||||||
|
|
||||||
def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
|
def testCases: Seq[(String, File)] = {
|
||||||
|
hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
|
||||||
|
}
|
||||||
|
|
||||||
override def beforeAll() {
|
override def beforeAll() {
|
||||||
|
super.beforeAll()
|
||||||
TestHive.cacheTables = true
|
TestHive.cacheTables = true
|
||||||
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
|
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
|
||||||
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
|
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
|
||||||
|
@ -68,10 +71,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
||||||
|
|
||||||
// For debugging dump some statistics about how much time was spent in various optimizer rules.
|
// For debugging dump some statistics about how much time was spent in various optimizer rules.
|
||||||
logWarning(RuleExecutor.dumpTimeSpent())
|
logWarning(RuleExecutor.dumpTimeSpent())
|
||||||
|
super.afterAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
/** A list of tests deemed out of scope currently and thus completely disregarded. */
|
/** A list of tests deemed out of scope currently and thus completely disregarded. */
|
||||||
override def blackList = Seq(
|
override def blackList: Seq[String] = Seq(
|
||||||
// These tests use hooks that are not on the classpath and thus break all subsequent execution.
|
// These tests use hooks that are not on the classpath and thus break all subsequent execution.
|
||||||
"hook_order",
|
"hook_order",
|
||||||
"hook_context_cs",
|
"hook_context_cs",
|
||||||
|
@ -106,7 +110,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
||||||
"alter_merge",
|
"alter_merge",
|
||||||
"alter_concatenate_indexed_table",
|
"alter_concatenate_indexed_table",
|
||||||
"protectmode2",
|
"protectmode2",
|
||||||
//"describe_table",
|
// "describe_table",
|
||||||
"describe_comment_nonascii",
|
"describe_comment_nonascii",
|
||||||
|
|
||||||
"create_merge_compressed",
|
"create_merge_compressed",
|
||||||
|
@ -323,7 +327,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
||||||
* The set of tests that are believed to be working in catalyst. Tests not on whiteList or
|
* The set of tests that are believed to be working in catalyst. Tests not on whiteList or
|
||||||
* blacklist are implicitly marked as ignored.
|
* blacklist are implicitly marked as ignored.
|
||||||
*/
|
*/
|
||||||
override def whiteList = Seq(
|
override def whiteList: Seq[String] = Seq(
|
||||||
"add_part_exist",
|
"add_part_exist",
|
||||||
"add_part_multiple",
|
"add_part_multiple",
|
||||||
"add_partition_no_whitelist",
|
"add_partition_no_whitelist",
|
||||||
|
|
|
@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte
|
||||||
TimeZone.setDefault(originalTimeZone)
|
TimeZone.setDefault(originalTimeZone)
|
||||||
Locale.setDefault(originalLocale)
|
Locale.setDefault(originalLocale)
|
||||||
TestHive.reset()
|
TestHive.reset()
|
||||||
|
super.afterAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -668,7 +668,8 @@ private[hive] object HiveQl extends SparkQl with Logging {
|
||||||
Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse(
|
Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse(
|
||||||
sys.error(s"Couldn't find function $functionName"))
|
sys.error(s"Couldn't find function $functionName"))
|
||||||
val functionClassName = functionInfo.getFunctionClass.getName
|
val functionClassName = functionInfo.getFunctionClass.getName
|
||||||
HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr))
|
HiveGenericUDTF(
|
||||||
|
functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr))
|
||||||
case other => super.nodeToGenerator(node)
|
case other => super.nodeToGenerator(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,244 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.hive
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicLong
|
||||||
|
|
||||||
|
import org.apache.spark.Logging
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
|
||||||
|
import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
|
||||||
|
import org.apache.spark.sql.catalyst.plans.logical._
|
||||||
|
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
|
||||||
|
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||||
|
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
|
||||||
|
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A builder class used to convert a resolved logical plan into a SQL query string. Note that this
|
||||||
|
* all resolved logical plan are convertible. They either don't have corresponding SQL
|
||||||
|
* representations (e.g. logical plans that operate on local Scala collections), or are simply not
|
||||||
|
* supported by this builder (yet).
|
||||||
|
*/
|
||||||
|
class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
|
||||||
|
def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
|
||||||
|
|
||||||
|
def toSQL: Option[String] = {
|
||||||
|
val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
|
||||||
|
val maybeSQL = try {
|
||||||
|
toSQL(canonicalizedPlan)
|
||||||
|
} catch { case cause: UnsupportedOperationException =>
|
||||||
|
logInfo(s"Failed to build SQL query string because: ${cause.getMessage}")
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
if (maybeSQL.isDefined) {
|
||||||
|
logDebug(
|
||||||
|
s"""Built SQL query string successfully from given logical plan:
|
||||||
|
|
|
||||||
|
|# Original logical plan:
|
||||||
|
|${logicalPlan.treeString}
|
||||||
|
|# Canonicalized logical plan:
|
||||||
|
|${canonicalizedPlan.treeString}
|
||||||
|
|# Built SQL query string:
|
||||||
|
|${maybeSQL.get}
|
||||||
|
""".stripMargin)
|
||||||
|
} else {
|
||||||
|
logDebug(
|
||||||
|
s"""Failed to build SQL query string from given logical plan:
|
||||||
|
|
|
||||||
|
|# Original logical plan:
|
||||||
|
|${logicalPlan.treeString}
|
||||||
|
|# Canonicalized logical plan:
|
||||||
|
|${canonicalizedPlan.treeString}
|
||||||
|
""".stripMargin)
|
||||||
|
}
|
||||||
|
|
||||||
|
maybeSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
private def projectToSQL(
|
||||||
|
projectList: Seq[NamedExpression],
|
||||||
|
child: LogicalPlan,
|
||||||
|
isDistinct: Boolean): Option[String] = {
|
||||||
|
for {
|
||||||
|
childSQL <- toSQL(child)
|
||||||
|
listSQL = projectList.map(_.sql).mkString(", ")
|
||||||
|
maybeFrom = child match {
|
||||||
|
case OneRowRelation => " "
|
||||||
|
case _ => " FROM "
|
||||||
|
}
|
||||||
|
distinct = if (isDistinct) " DISTINCT " else " "
|
||||||
|
} yield s"SELECT$distinct$listSQL$maybeFrom$childSQL"
|
||||||
|
}
|
||||||
|
|
||||||
|
private def aggregateToSQL(
|
||||||
|
groupingExprs: Seq[Expression],
|
||||||
|
aggExprs: Seq[Expression],
|
||||||
|
child: LogicalPlan): Option[String] = {
|
||||||
|
val aggSQL = aggExprs.map(_.sql).mkString(", ")
|
||||||
|
val groupingSQL = groupingExprs.map(_.sql).mkString(", ")
|
||||||
|
val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY "
|
||||||
|
val maybeFrom = child match {
|
||||||
|
case OneRowRelation => " "
|
||||||
|
case _ => " FROM "
|
||||||
|
}
|
||||||
|
|
||||||
|
toSQL(child).map { childSQL =>
|
||||||
|
s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def toSQL(node: LogicalPlan): Option[String] = node match {
|
||||||
|
case Distinct(Project(list, child)) =>
|
||||||
|
projectToSQL(list, child, isDistinct = true)
|
||||||
|
|
||||||
|
case Project(list, child) =>
|
||||||
|
projectToSQL(list, child, isDistinct = false)
|
||||||
|
|
||||||
|
case Aggregate(groupingExprs, aggExprs, child) =>
|
||||||
|
aggregateToSQL(groupingExprs, aggExprs, child)
|
||||||
|
|
||||||
|
case Limit(limit, child) =>
|
||||||
|
for {
|
||||||
|
childSQL <- toSQL(child)
|
||||||
|
limitSQL = limit.sql
|
||||||
|
} yield s"$childSQL LIMIT $limitSQL"
|
||||||
|
|
||||||
|
case Filter(condition, child) =>
|
||||||
|
for {
|
||||||
|
childSQL <- toSQL(child)
|
||||||
|
whereOrHaving = child match {
|
||||||
|
case _: Aggregate => "HAVING"
|
||||||
|
case _ => "WHERE"
|
||||||
|
}
|
||||||
|
conditionSQL = condition.sql
|
||||||
|
} yield s"$childSQL $whereOrHaving $conditionSQL"
|
||||||
|
|
||||||
|
case Union(left, right) =>
|
||||||
|
for {
|
||||||
|
leftSQL <- toSQL(left)
|
||||||
|
rightSQL <- toSQL(right)
|
||||||
|
} yield s"$leftSQL UNION ALL $rightSQL"
|
||||||
|
|
||||||
|
// ParquetRelation converted from Hive metastore table
|
||||||
|
case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) =>
|
||||||
|
// There seems to be a bug related to `ParquetConversions` analysis rule. The problem is
|
||||||
|
// that, the metastore database name and table name are not always propagated to converted
|
||||||
|
// `ParquetRelation` instances via data source options. Here we use subquery alias as a
|
||||||
|
// workaround.
|
||||||
|
Some(s"`$alias`")
|
||||||
|
|
||||||
|
case Subquery(alias, child) =>
|
||||||
|
toSQL(child).map(childSQL => s"($childSQL) AS $alias")
|
||||||
|
|
||||||
|
case Join(left, right, joinType, condition) =>
|
||||||
|
for {
|
||||||
|
leftSQL <- toSQL(left)
|
||||||
|
rightSQL <- toSQL(right)
|
||||||
|
joinTypeSQL = joinType.sql
|
||||||
|
conditionSQL = condition.map(" ON " + _.sql).getOrElse("")
|
||||||
|
} yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL"
|
||||||
|
|
||||||
|
case MetastoreRelation(database, table, alias) =>
|
||||||
|
val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("")
|
||||||
|
Some(s"`$database`.`$table`$aliasSQL")
|
||||||
|
|
||||||
|
case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
|
||||||
|
if orders.map(_.child) == partitionExprs =>
|
||||||
|
for {
|
||||||
|
childSQL <- toSQL(child)
|
||||||
|
partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
|
||||||
|
} yield s"$childSQL CLUSTER BY $partitionExprsSQL"
|
||||||
|
|
||||||
|
case Sort(orders, global, child) =>
|
||||||
|
for {
|
||||||
|
childSQL <- toSQL(child)
|
||||||
|
ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
|
||||||
|
orderOrSort = if (global) "ORDER" else "SORT"
|
||||||
|
} yield s"$childSQL $orderOrSort BY $ordersSQL"
|
||||||
|
|
||||||
|
case RepartitionByExpression(partitionExprs, child, _) =>
|
||||||
|
for {
|
||||||
|
childSQL <- toSQL(child)
|
||||||
|
partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
|
||||||
|
} yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL"
|
||||||
|
|
||||||
|
case OneRowRelation =>
|
||||||
|
Some("")
|
||||||
|
|
||||||
|
case _ => None
|
||||||
|
}
|
||||||
|
|
||||||
|
object Canonicalizer extends RuleExecutor[LogicalPlan] {
|
||||||
|
override protected def batches: Seq[Batch] = Seq(
|
||||||
|
Batch("Canonicalizer", FixedPoint(100),
|
||||||
|
// The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
|
||||||
|
// `Aggregate`s to perform type casting. This rule merges these `Project`s into
|
||||||
|
// `Aggregate`s.
|
||||||
|
ProjectCollapsing,
|
||||||
|
|
||||||
|
// Used to handle other auxiliary `Project`s added by analyzer (e.g.
|
||||||
|
// `ResolveAggregateFunctions` rule)
|
||||||
|
RecoverScopingInfo
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
object RecoverScopingInfo extends Rule[LogicalPlan] {
|
||||||
|
override def apply(tree: LogicalPlan): LogicalPlan = tree transform {
|
||||||
|
// This branch handles aggregate functions within HAVING clauses. For example:
|
||||||
|
//
|
||||||
|
// SELECT key FROM src GROUP BY key HAVING max(value) > "val_255"
|
||||||
|
//
|
||||||
|
// This kind of query results in query plans of the following form because of analysis rule
|
||||||
|
// `ResolveAggregateFunctions`:
|
||||||
|
//
|
||||||
|
// Project ...
|
||||||
|
// +- Filter ...
|
||||||
|
// +- Aggregate ...
|
||||||
|
// +- MetastoreRelation default, src, None
|
||||||
|
case plan @ Project(_, Filter(_, _: Aggregate)) =>
|
||||||
|
wrapChildWithSubquery(plan)
|
||||||
|
|
||||||
|
case plan @ Project(_,
|
||||||
|
_: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit
|
||||||
|
) => plan
|
||||||
|
|
||||||
|
case plan: Project =>
|
||||||
|
wrapChildWithSubquery(plan)
|
||||||
|
}
|
||||||
|
|
||||||
|
def wrapChildWithSubquery(project: Project): Project = project match {
|
||||||
|
case Project(projectList, child) =>
|
||||||
|
val alias = SQLBuilder.newSubqueryName
|
||||||
|
val childAttributes = child.outputSet
|
||||||
|
val aliasedProjectList = projectList.map(_.transform {
|
||||||
|
case a: Attribute if childAttributes.contains(a) =>
|
||||||
|
a.withQualifiers(alias :: Nil)
|
||||||
|
}.asInstanceOf[NamedExpression])
|
||||||
|
|
||||||
|
Project(aliasedProjectList, Subquery(alias, child))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object SQLBuilder {
|
||||||
|
private val nextSubqueryId = new AtomicLong(0)
|
||||||
|
|
||||||
|
private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}"
|
||||||
|
}
|
|
@ -17,30 +17,26 @@
|
||||||
|
|
||||||
package org.apache.spark.sql.hive
|
package org.apache.spark.sql.hive
|
||||||
|
|
||||||
import scala.collection.mutable.ArrayBuffer
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
import scala.collection.mutable.ArrayBuffer
|
||||||
import scala.util.Try
|
import scala.util.Try
|
||||||
|
|
||||||
import org.apache.hadoop.hive.ql.exec._
|
import org.apache.hadoop.hive.ql.exec._
|
||||||
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
|
|
||||||
import org.apache.hadoop.hive.ql.udf.generic._
|
|
||||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
|
|
||||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
|
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
|
||||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
|
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
|
||||||
import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
|
import org.apache.hadoop.hive.ql.udf.generic._
|
||||||
|
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
|
||||||
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
|
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
|
||||||
|
import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
|
||||||
|
|
||||||
import org.apache.spark.Logging
|
import org.apache.spark.Logging
|
||||||
import org.apache.spark.sql.AnalysisException
|
import org.apache.spark.sql.AnalysisException
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
|
||||||
import org.apache.spark.sql.catalyst.analysis
|
|
||||||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
|
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
||||||
import org.apache.spark.sql.catalyst.plans.logical._
|
import org.apache.spark.sql.catalyst.util.sequenceOption
|
||||||
import org.apache.spark.sql.catalyst.rules.Rule
|
import org.apache.spark.sql.catalyst.{InternalRow, analysis}
|
||||||
import org.apache.spark.sql.catalyst.util.ArrayData
|
|
||||||
import org.apache.spark.sql.hive.HiveShim._
|
import org.apache.spark.sql.hive.HiveShim._
|
||||||
import org.apache.spark.sql.hive.client.ClientWrapper
|
import org.apache.spark.sql.hive.client.ClientWrapper
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
@ -75,19 +71,19 @@ private[hive] class HiveFunctionRegistry(
|
||||||
try {
|
try {
|
||||||
if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
|
if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
|
||||||
HiveGenericUDF(
|
HiveGenericUDF(
|
||||||
new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
|
name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
|
||||||
} else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
} else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
||||||
HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children)
|
HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
|
||||||
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
||||||
HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
|
HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
|
||||||
} else if (
|
} else if (
|
||||||
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
|
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
|
||||||
HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
|
HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
|
||||||
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
||||||
HiveUDAFFunction(
|
HiveUDAFFunction(
|
||||||
new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
|
name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
|
||||||
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
|
||||||
val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
|
val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
|
||||||
udtf.elementTypes // Force it to check input data types.
|
udtf.elementTypes // Force it to check input data types.
|
||||||
udtf
|
udtf
|
||||||
} else {
|
} else {
|
||||||
|
@ -137,7 +133,8 @@ private[hive] class HiveFunctionRegistry(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
|
private[hive] case class HiveSimpleUDF(
|
||||||
|
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
|
||||||
extends Expression with HiveInspectors with CodegenFallback with Logging {
|
extends Expression with HiveInspectors with CodegenFallback with Logging {
|
||||||
|
|
||||||
override def deterministic: Boolean = isUDFDeterministic
|
override def deterministic: Boolean = isUDFDeterministic
|
||||||
|
@ -191,6 +188,8 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
|
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
|
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
|
||||||
|
@ -205,7 +204,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
|
||||||
override def get(): AnyRef = wrap(func(), oi, dataType)
|
override def get(): AnyRef = wrap(func(), oi, dataType)
|
||||||
}
|
}
|
||||||
|
|
||||||
private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
|
private[hive] case class HiveGenericUDF(
|
||||||
|
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
|
||||||
extends Expression with HiveInspectors with CodegenFallback with Logging {
|
extends Expression with HiveInspectors with CodegenFallback with Logging {
|
||||||
|
|
||||||
override def nullable: Boolean = true
|
override def nullable: Boolean = true
|
||||||
|
@ -257,6 +257,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
|
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -271,6 +273,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
|
||||||
* user defined aggregations, which have clean semantics even in a partitioned execution.
|
* user defined aggregations, which have clean semantics even in a partitioned execution.
|
||||||
*/
|
*/
|
||||||
private[hive] case class HiveGenericUDTF(
|
private[hive] case class HiveGenericUDTF(
|
||||||
|
name: String,
|
||||||
funcWrapper: HiveFunctionWrapper,
|
funcWrapper: HiveFunctionWrapper,
|
||||||
children: Seq[Expression])
|
children: Seq[Expression])
|
||||||
extends Generator with HiveInspectors with CodegenFallback {
|
extends Generator with HiveInspectors with CodegenFallback {
|
||||||
|
@ -336,6 +339,8 @@ private[hive] case class HiveGenericUDTF(
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
|
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -343,6 +348,7 @@ private[hive] case class HiveGenericUDTF(
|
||||||
* performance a lot.
|
* performance a lot.
|
||||||
*/
|
*/
|
||||||
private[hive] case class HiveUDAFFunction(
|
private[hive] case class HiveUDAFFunction(
|
||||||
|
name: String,
|
||||||
funcWrapper: HiveFunctionWrapper,
|
funcWrapper: HiveFunctionWrapper,
|
||||||
children: Seq[Expression],
|
children: Seq[Expression],
|
||||||
isUDAFBridgeRequired: Boolean = false,
|
isUDAFBridgeRequired: Boolean = false,
|
||||||
|
@ -427,5 +433,9 @@ private[hive] case class HiveUDAFFunction(
|
||||||
override def supportsPartial: Boolean = false
|
override def supportsPartial: Boolean = false
|
||||||
|
|
||||||
override val dataType: DataType = inspectorToDataType(returnInspector)
|
override val dataType: DataType = inspectorToDataType(returnInspector)
|
||||||
}
|
|
||||||
|
|
||||||
|
override def sql(isDistinct: Boolean): String = {
|
||||||
|
val distinct = if (isDistinct) "DISTINCT " else " "
|
||||||
|
s"$name($distinct${children.map(_.sql).mkString(", ")})"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.hive
|
||||||
|
|
||||||
|
import java.sql.Timestamp
|
||||||
|
|
||||||
|
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
|
||||||
|
|
||||||
|
class ExpressionSQLBuilderSuite extends SQLBuilderTest {
|
||||||
|
test("literal") {
|
||||||
|
checkSQL(Literal("foo"), "\"foo\"")
|
||||||
|
checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
|
||||||
|
checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
|
||||||
|
checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
|
||||||
|
checkSQL(Literal(4: Int), "4")
|
||||||
|
checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
|
||||||
|
checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
|
||||||
|
checkSQL(Literal(2.5D), "2.5")
|
||||||
|
checkSQL(
|
||||||
|
Literal(Timestamp.valueOf("2016-01-01 00:00:00")),
|
||||||
|
"TIMESTAMP('2016-01-01 00:00:00.0')")
|
||||||
|
// TODO tests for decimals
|
||||||
|
}
|
||||||
|
|
||||||
|
test("binary comparisons") {
|
||||||
|
checkSQL('a.int === 'b.int, "(`a` = `b`)")
|
||||||
|
checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
|
||||||
|
checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))")
|
||||||
|
|
||||||
|
checkSQL('a.int < 'b.int, "(`a` < `b`)")
|
||||||
|
checkSQL('a.int <= 'b.int, "(`a` <= `b`)")
|
||||||
|
checkSQL('a.int > 'b.int, "(`a` > `b`)")
|
||||||
|
checkSQL('a.int >= 'b.int, "(`a` >= `b`)")
|
||||||
|
|
||||||
|
checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))")
|
||||||
|
checkSQL('a.int in (1, 2), "(`a` IN (1, 2))")
|
||||||
|
|
||||||
|
checkSQL('a.int.isNull, "(`a` IS NULL)")
|
||||||
|
checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("logical operators") {
|
||||||
|
checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)")
|
||||||
|
checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)")
|
||||||
|
checkSQL(!'a.boolean, "(NOT `a`)")
|
||||||
|
checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("arithmetic expressions") {
|
||||||
|
checkSQL('a.int + 'b.int, "(`a` + `b`)")
|
||||||
|
checkSQL('a.int - 'b.int, "(`a` - `b`)")
|
||||||
|
checkSQL('a.int * 'b.int, "(`a` * `b`)")
|
||||||
|
checkSQL('a.int / 'b.int, "(`a` / `b`)")
|
||||||
|
checkSQL('a.int % 'b.int, "(`a` % `b`)")
|
||||||
|
|
||||||
|
checkSQL(-'a.int, "(-`a`)")
|
||||||
|
checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,146 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.hive
|
||||||
|
|
||||||
|
import org.apache.spark.sql.test.SQLTestUtils
|
||||||
|
import org.apache.spark.sql.functions._
|
||||||
|
|
||||||
|
class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
|
||||||
|
import testImplicits._
|
||||||
|
|
||||||
|
protected override def beforeAll(): Unit = {
|
||||||
|
sqlContext.range(10).write.saveAsTable("t0")
|
||||||
|
|
||||||
|
sqlContext
|
||||||
|
.range(10)
|
||||||
|
.select('id as 'key, concat(lit("val_"), 'id) as 'value)
|
||||||
|
.write
|
||||||
|
.saveAsTable("t1")
|
||||||
|
|
||||||
|
sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
|
||||||
|
}
|
||||||
|
|
||||||
|
override protected def afterAll(): Unit = {
|
||||||
|
sql("DROP TABLE IF EXISTS t0")
|
||||||
|
sql("DROP TABLE IF EXISTS t1")
|
||||||
|
sql("DROP TABLE IF EXISTS t2")
|
||||||
|
}
|
||||||
|
|
||||||
|
private def checkHiveQl(hiveQl: String): Unit = {
|
||||||
|
val df = sql(hiveQl)
|
||||||
|
val convertedSQL = new SQLBuilder(df).toSQL
|
||||||
|
|
||||||
|
if (convertedSQL.isEmpty) {
|
||||||
|
fail(
|
||||||
|
s"""Cannot convert the following HiveQL query plan back to SQL query string:
|
||||||
|
|
|
||||||
|
|# Original HiveQL query string:
|
||||||
|
|$hiveQl
|
||||||
|
|
|
||||||
|
|# Resolved query plan:
|
||||||
|
|${df.queryExecution.analyzed.treeString}
|
||||||
|
""".stripMargin)
|
||||||
|
}
|
||||||
|
|
||||||
|
val sqlString = convertedSQL.get
|
||||||
|
try {
|
||||||
|
checkAnswer(sql(sqlString), df)
|
||||||
|
} catch { case cause: Throwable =>
|
||||||
|
fail(
|
||||||
|
s"""Failed to execute converted SQL string or got wrong answer:
|
||||||
|
|
|
||||||
|
|# Converted SQL query string:
|
||||||
|
|$sqlString
|
||||||
|
|
|
||||||
|
|# Original HiveQL query string:
|
||||||
|
|$hiveQl
|
||||||
|
|
|
||||||
|
|# Resolved query plan:
|
||||||
|
|${df.queryExecution.analyzed.treeString}
|
||||||
|
""".stripMargin,
|
||||||
|
cause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("in") {
|
||||||
|
checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("aggregate function in having clause") {
|
||||||
|
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("aggregate function in order by clause") {
|
||||||
|
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO Fix name collision introduced by ResolveAggregateFunction analysis rule
|
||||||
|
// When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into
|
||||||
|
// Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query
|
||||||
|
// execution since these aliases have different expression ID. But this introduces name collision
|
||||||
|
// when converting resolved plans back to SQL query strings as expression IDs are stripped.
|
||||||
|
ignore("aggregate function in order by clause with multiple order keys") {
|
||||||
|
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("type widening in union") {
|
||||||
|
checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("case") {
|
||||||
|
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("case with else") {
|
||||||
|
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("case with key") {
|
||||||
|
checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("case with key and else") {
|
||||||
|
checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("select distinct without aggregate functions") {
|
||||||
|
checkHiveQl("SELECT DISTINCT id FROM t0")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("cluster by") {
|
||||||
|
checkHiveQl("SELECT id FROM t0 CLUSTER BY id")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("distribute by") {
|
||||||
|
checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("distribute by with sort by") {
|
||||||
|
checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id")
|
||||||
|
}
|
||||||
|
|
||||||
|
test("distinct aggregation") {
|
||||||
|
checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO Enable this
|
||||||
|
// Query plans transformed by DistinctAggregationRewriter are not recognized yet
|
||||||
|
ignore("distinct and non-distinct aggregation") {
|
||||||
|
checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql.hive
|
||||||
|
|
||||||
|
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||||
|
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||||
|
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
||||||
|
import org.apache.spark.sql.{DataFrame, QueryTest}
|
||||||
|
|
||||||
|
abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
|
||||||
|
protected def checkSQL(e: Expression, expectedSQL: String): Unit = {
|
||||||
|
val actualSQL = e.sql
|
||||||
|
try {
|
||||||
|
assert(actualSQL === expectedSQL)
|
||||||
|
} catch {
|
||||||
|
case cause: Throwable =>
|
||||||
|
fail(
|
||||||
|
s"""Wrong SQL generated for the following expression:
|
||||||
|
|
|
||||||
|
|${e.prettyName}
|
||||||
|
|
|
||||||
|
|$cause
|
||||||
|
""".stripMargin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = {
|
||||||
|
val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL
|
||||||
|
|
||||||
|
if (maybeSQL.isEmpty) {
|
||||||
|
fail(
|
||||||
|
s"""Cannot convert the following logical query plan to SQL:
|
||||||
|
|
|
||||||
|
|${plan.treeString}
|
||||||
|
""".stripMargin)
|
||||||
|
}
|
||||||
|
|
||||||
|
val actualSQL = maybeSQL.get
|
||||||
|
|
||||||
|
try {
|
||||||
|
assert(actualSQL === expectedSQL)
|
||||||
|
} catch {
|
||||||
|
case cause: Throwable =>
|
||||||
|
fail(
|
||||||
|
s"""Wrong SQL generated for the following logical query plan:
|
||||||
|
|
|
||||||
|
|${plan.treeString}
|
||||||
|
|
|
||||||
|
|$cause
|
||||||
|
""".stripMargin)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan))
|
||||||
|
}
|
||||||
|
|
||||||
|
protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
|
||||||
|
checkSQL(df.queryExecution.analyzed, expectedSQL)
|
||||||
|
}
|
||||||
|
}
|
|
@ -27,9 +27,10 @@ import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
|
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
|
||||||
import org.apache.spark.sql.catalyst.plans.logical._
|
import org.apache.spark.sql.catalyst.plans.logical._
|
||||||
import org.apache.spark.sql.catalyst.util._
|
import org.apache.spark.sql.catalyst.util._
|
||||||
import org.apache.spark.sql.execution.{ExplainCommand, SetCommand}
|
|
||||||
import org.apache.spark.sql.execution.datasources.DescribeCommand
|
import org.apache.spark.sql.execution.datasources.DescribeCommand
|
||||||
|
import org.apache.spark.sql.execution.{ExplainCommand, SetCommand}
|
||||||
import org.apache.spark.sql.hive.test.TestHive
|
import org.apache.spark.sql.hive.test.TestHive
|
||||||
|
import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Allows the creations of tests that execute the same query against both hive
|
* Allows the creations of tests that execute the same query against both hive
|
||||||
|
@ -130,6 +131,28 @@ abstract class HiveComparisonTest
|
||||||
new java.math.BigInteger(1, digest.digest).toString(16)
|
new java.math.BigInteger(1, digest.digest).toString(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Used for testing [[SQLBuilder]] */
|
||||||
|
private var numConvertibleQueries: Int = 0
|
||||||
|
private var numTotalQueries: Int = 0
|
||||||
|
|
||||||
|
override protected def afterAll(): Unit = {
|
||||||
|
logInfo({
|
||||||
|
val percentage = if (numTotalQueries > 0) {
|
||||||
|
numConvertibleQueries.toDouble / numTotalQueries * 100
|
||||||
|
} else {
|
||||||
|
0D
|
||||||
|
}
|
||||||
|
|
||||||
|
s"""SQLBuiler statistics:
|
||||||
|
|- Total query number: $numTotalQueries
|
||||||
|
|- Number of convertible queries: $numConvertibleQueries
|
||||||
|
|- Percentage of convertible queries: $percentage%
|
||||||
|
""".stripMargin
|
||||||
|
})
|
||||||
|
|
||||||
|
super.afterAll()
|
||||||
|
}
|
||||||
|
|
||||||
protected def prepareAnswer(
|
protected def prepareAnswer(
|
||||||
hiveQuery: TestHive.type#QueryExecution,
|
hiveQuery: TestHive.type#QueryExecution,
|
||||||
answer: Seq[String]): Seq[String] = {
|
answer: Seq[String]): Seq[String] = {
|
||||||
|
@ -372,8 +395,49 @@ abstract class HiveComparisonTest
|
||||||
|
|
||||||
// Run w/ catalyst
|
// Run w/ catalyst
|
||||||
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
|
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
|
||||||
val query = new TestHive.QueryExecution(queryString)
|
var query: TestHive.QueryExecution = null
|
||||||
try { (query, prepareAnswer(query, query.stringResult())) } catch {
|
try {
|
||||||
|
query = {
|
||||||
|
val originalQuery = new TestHive.QueryExecution(queryString)
|
||||||
|
val containsCommands = originalQuery.analyzed.collectFirst {
|
||||||
|
case _: Command => ()
|
||||||
|
case _: LogicalInsertIntoHiveTable => ()
|
||||||
|
}.nonEmpty
|
||||||
|
|
||||||
|
if (containsCommands) {
|
||||||
|
originalQuery
|
||||||
|
} else {
|
||||||
|
numTotalQueries += 1
|
||||||
|
new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql =>
|
||||||
|
numConvertibleQueries += 1
|
||||||
|
logInfo(
|
||||||
|
s"""
|
||||||
|
|### Running SQL generation round-trip test {{{
|
||||||
|
|${originalQuery.analyzed.treeString}
|
||||||
|
|Original SQL:
|
||||||
|
|$queryString
|
||||||
|
|
|
||||||
|
|Generated SQL:
|
||||||
|
|$sql
|
||||||
|
|}}}
|
||||||
|
""".stripMargin.trim)
|
||||||
|
new TestHive.QueryExecution(sql)
|
||||||
|
}.getOrElse {
|
||||||
|
logInfo(
|
||||||
|
s"""
|
||||||
|
|### Cannot convert the following logical plan back to SQL {{{
|
||||||
|
|${originalQuery.analyzed.treeString}
|
||||||
|
|Original SQL:
|
||||||
|
|$queryString
|
||||||
|
|}}}
|
||||||
|
""".stripMargin.trim)
|
||||||
|
originalQuery
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(query, prepareAnswer(query, query.stringResult()))
|
||||||
|
} catch {
|
||||||
case e: Throwable =>
|
case e: Throwable =>
|
||||||
val errorMessage =
|
val errorMessage =
|
||||||
s"""
|
s"""
|
||||||
|
|
|
@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|
||||||
TimeZone.setDefault(originalTimeZone)
|
TimeZone.setDefault(originalTimeZone)
|
||||||
Locale.setDefault(originalLocale)
|
Locale.setDefault(originalLocale)
|
||||||
sql("DROP TEMPORARY FUNCTION udtf_count2")
|
sql("DROP TEMPORARY FUNCTION udtf_count2")
|
||||||
|
super.afterAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
test("SPARK-4908: concurrent hive native commands") {
|
test("SPARK-4908: concurrent hive native commands") {
|
||||||
|
|
Loading…
Reference in a new issue