[SPARK-12799] Simplify various string output for expressions

This PR introduces several major changes:

1. Replacing `Expression.prettyString` with `Expression.sql`

   The `prettyString` method is mostly an internal, developer faced facility for debugging purposes, and shouldn't be exposed to users.

1. Using SQL-like representation as column names for selected fields that are not named expression (back-ticks and double quotes should be removed)

   Before, we were using `prettyString` as column names when possible, and sometimes the result column names can be weird.  Here are several examples:

   Expression         | `prettyString` | `sql`      | Note
   ------------------ | -------------- | ---------- | ---------------
   `a && b`           | `a && b`       | `a AND b`  |
   `a.getField("f")`  | `a[f]`         | `a.f`      | `a` is a struct

1. Adding trait `NonSQLExpression` extending from `Expression` for expressions that don't have a SQL representation (e.g. Scala UDF/UDAF and Java/Scala object expressions used for encoders)

   `NonSQLExpression.sql` may return an arbitrary user facing string representation of the expression.

Author: Cheng Lian <lian@databricks.com>

Closes #10757 from liancheng/spark-12799.simplify-expression-string-methods.
This commit is contained in:
Cheng Lian 2016-02-21 22:53:15 +08:00
parent d806ed3436
commit d9efe63ecd
49 changed files with 403 additions and 279 deletions

View file

@ -1047,13 +1047,13 @@ test_that("column functions", {
schema = c("a", "b", "c"))
result <- collect(select(df, struct("a", "c")))
expected <- data.frame(row.names = 1:2)
expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)),
expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)),
listToStruct(list(a = 4L, c = 6L)))
expect_equal(result, expected)
result <- collect(select(df, struct(df$a, df$b)))
expected <- data.frame(row.names = 1:2)
expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)),
expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)),
listToStruct(list(a = 4L, b = 5L)))
expect_equal(result, expected)

View file

@ -219,17 +219,17 @@ class Column(object):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
+----+
|r[b]|
+----+
| b|
+----+
+---+
|r.b|
+---+
| b|
+---+
>>> df.select(df.r.a).show()
+----+
|r[a]|
+----+
| 1|
+----+
+---+
|r.a|
+---+
| 1|
+---+
"""
return self[name]
@ -346,12 +346,12 @@ class Column(object):
expression is between the given columns.
>>> df.select(df.name, df.age.between(2, 4)).show()
+-----+--------------------------+
| name|((age >= 2) && (age <= 4))|
+-----+--------------------------+
|Alice| true|
| Bob| false|
+-----+--------------------------+
+-----+---------------------------+
| name|((age >= 2) AND (age <= 4))|
+-----+---------------------------+
|Alice| true|
| Bob| false|
+-----+---------------------------+
"""
return (self >= lowerBound) & (self <= upperBound)

View file

@ -92,8 +92,8 @@ class SQLContext(object):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
[Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@ -210,17 +210,17 @@ class SQLContext(object):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(_c0=u'4')]
[Row(stringLengthString(test)=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(_c0=4)]
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(_c0=4)]
[Row(stringLengthInt(test)=4)]
"""
udf = UserDefinedFunction(f, returnType, name)
self._ssql_ctx.udf().registerPython(name, udf._judf)

View file

@ -223,22 +223,22 @@ def coalesce(*cols):
+----+----+
>>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
+-------------+
|coalesce(a,b)|
+-------------+
| null|
| 1|
| 2|
+-------------+
+--------------+
|coalesce(a, b)|
+--------------+
| null|
| 1|
| 2|
+--------------+
>>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
+----+----+---------------+
| a| b|coalesce(a,0.0)|
+----+----+---------------+
|null|null| 0.0|
| 1|null| 1.0|
|null| 2| 0.0|
+----+----+---------------+
+----+----+----------------+
| a| b|coalesce(a, 0.0)|
+----+----+----------------+
|null|null| 0.0|
| 1|null| 1.0|
|null| 2| 0.0|
+----+----+----------------+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column))
@ -1528,7 +1528,7 @@ def array_contains(col, value):
>>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
>>> df.select(array_contains(df.data, "a")).collect()
[Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)]
[Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))

View file

@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.types._
/**
@ -165,7 +166,8 @@ class Analyzer(
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))()
case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)()
case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))()
}
}
}.asInstanceOf[Seq[NamedExpression]]
@ -328,7 +330,7 @@ class Analyzer(
throw new AnalysisException(
s"Aggregate expression required for pivot, found '$aggregate'")
}
val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
Alias(filteredAggregate, name)()
}
}
@ -1456,7 +1458,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +

View file

@ -57,13 +57,13 @@ trait CheckAnalysis {
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]")
a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")
case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
e.failAnalysis(
s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
s"cannot resolve '${e.sql}' due to data type mismatch: $message")
}
case c: Cast if !c.resolved =>
@ -106,12 +106,12 @@ trait CheckAnalysis {
operator match {
case f: Filter if f.condition.dataType != BooleanType =>
failAnalysis(
s"filter expression '${f.condition.prettyString}' " +
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
s"join condition '${condition.prettyString}' " +
s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.simpleString} is not a boolean.")
case j @ Join(_, _, _, Some(condition)) =>
@ -119,10 +119,10 @@ trait CheckAnalysis {
case p: Predicate =>
p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
case e if e.dataType.isInstanceOf[BinaryType] =>
failAnalysis(s"binary type expression ${e.prettyString} cannot be used " +
failAnalysis(s"binary type expression ${e.sql} cannot be used " +
"in join conditions")
case e if e.dataType.isInstanceOf[MapType] =>
failAnalysis(s"map type expression ${e.prettyString} cannot be used " +
failAnalysis(s"map type expression ${e.sql} cannot be used " +
"in join conditions")
case _ => // OK
}
@ -144,13 +144,13 @@ trait CheckAnalysis {
if (!child.deterministic) {
failAnalysis(
s"nondeterministic expression ${expr.prettyString} should not " +
s"nondeterministic expression ${expr.sql} should not " +
s"appear in the arguments of an aggregate function.")
}
}
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"expression '${e.sql}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() (or first_value) if you don't care " +
"which value you get.")
@ -163,7 +163,7 @@ trait CheckAnalysis {
// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
failAnalysis(
s"expression ${expr.prettyString} cannot be used as a grouping expression " +
s"expression ${expr.sql} cannot be used as a grouping expression " +
s"because its data type ${expr.dataType.simpleString} is not a orderable " +
s"data type.")
}
@ -172,7 +172,7 @@ trait CheckAnalysis {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
// a Project node.
failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " +
failAnalysis(s"nondeterministic expression ${expr.sql} should not " +
s"appear in grouping expression.")
}
}
@ -217,7 +217,7 @@ trait CheckAnalysis {
case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
failAnalysis(
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
| ${exprs.map(_.sql).mkString(",")}""".stripMargin)
case j: Join if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
@ -248,9 +248,9 @@ trait CheckAnalysis {
failAnalysis(
s"""nondeterministic expressions are only allowed in
|Project, Filter, Aggregate or Window, found:
| ${o.expressions.map(_.prettyString).mkString(",")}
| ${o.expressions.map(_.sql).mkString(",")}
|in operator ${operator.simpleString}
""".stripMargin)
""".stripMargin)
case _ => // Analysis successful!
}

View file

@ -130,7 +130,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
new AttributeReference("gid", IntegerType, false)(isGenerated = true)
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)
@ -184,7 +184,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression(
@ -269,5 +269,5 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children, in this case attribute reuse causes the input of the regular aggregate to bound to
// the (nulled out) input of the distinct aggregate.
e -> new AttributeReference(e.prettyString, e.dataType, true)()
e -> new AttributeReference(e.sql, e.dataType, true)()
}

View file

@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.types.{DataType, StructType}
/**
@ -67,6 +68,8 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
override def toString: String = s"'$name"
override def sql: String = quoteIdentifier(name)
}
object UnresolvedAttribute {
@ -141,11 +144,8 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
override def prettyString: String = {
s"${name}(${children.map(_.prettyString).mkString(",")})"
}
override def toString: String = s"'$name(${children.mkString(",")})"
override def prettyName: String = name
override def toString: String = s"'$name(${children.mkString(", ")})"
}
/**
@ -208,10 +208,9 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
Alias(extract, f.name)()
}
case _ => {
case _ =>
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
target.get + "`")
}
}
} else {
val from = input.inputSet.map(_.name).mkString(", ")
@ -228,6 +227,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
* For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented
* as follows:
* MultiAlias(stack_function, Seq(a, b))
*
* @param child the computation being performed
* @param names the names to be associated with each output of computing [[child]].
@ -284,13 +284,14 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
override lazy val resolved = false
override def toString: String = s"$child[$extraction]"
override def sql: String = s"${child.sql}[${extraction.sql}]"
}
/**
* Holds the expression that has yet to be aliased.
*
* @param child The computation that is needs to be resolved during analysis.
* @param aliasName The name if specified to be asoosicated with the result of computing [[child]]
* @param aliasName The name if specified to be associated with the result of computing [[child]]
*
*/
case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)

View file

@ -45,7 +45,7 @@ trait ExpectsInputTypes extends Expression {
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
s"argument ${idx + 1} requires ${expected.simpleString} type, " +
s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type."
s"however, '${child.sql}' is of ${child.dataType.simpleString} type."
}
if (mismatches.isEmpty) {

View file

@ -18,10 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -96,7 +96,7 @@ abstract class Expression extends TreeNode[Expression] {
ctx.subExprEliminationExprs.get(this).map { subExprState =>
// This expression is repeated meaning the code to evaluated has already been added
// as a function and called in advance. Just use it.
val code = s"/* ${this.toCommentSafeString} */"
val code = s"/* ${toCommentSafeString(this.toString)} */"
ExprCode(code, subExprState.isNull, subExprState.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
@ -105,7 +105,7 @@ abstract class Expression extends TreeNode[Expression] {
ve.code = genCode(ctx, ve)
if (ve.code != "") {
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim)
} else {
ve
}
@ -201,17 +201,6 @@ abstract class Expression extends TreeNode[Expression] {
*/
def prettyName: String = getClass.getSimpleName.toLowerCase
/**
* Returns a user-facing string representation of this expression, i.e. does not have developer
* centric debugging information like the expression id.
*/
def prettyString: String = {
transform {
case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}
private def flatArguments = productIterator.flatMap {
case t: Traversable[_] => t
case single => single :: Nil
@ -219,24 +208,16 @@ abstract class Expression extends TreeNode[Expression] {
override def simpleString: String = toString
override def toString: String = prettyName + flatArguments.mkString("(", ",", ")")
override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")")
/**
* Returns the string representation of this expression that is safe to be put in
* code comments of generated code.
* Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]],
* this method may return an arbitrary user facing string.
*/
protected def toCommentSafeString: String = this.toString
.replace("*/", "\\*\\/")
.replace("\\u", "\\\\u")
/**
* Returns SQL representation of this expression. For expressions that don't have a SQL
* representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`.
*/
@throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation")
def sql: String = throw new UnsupportedOperationException(
s"Cannot map expression $this to its SQL representation"
)
def sql: String = {
val childrenSQL = children.map(_.sql).mkString(", ")
s"$prettyName($childrenSQL)"
}
}
@ -254,6 +235,19 @@ trait Unevaluable extends Expression {
}
/**
* Expressions that don't have SQL representation should extend this trait. Examples are
* `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
*/
trait NonSQLExpression extends Expression {
override def sql: String = {
transform {
case a: Attribute => new PrettyAttribute(a)
}.toString
}
}
/**
* An expression that is nondeterministic.
*/
@ -373,8 +367,6 @@ abstract class UnaryExpression extends Expression {
"""
}
}
override def sql: String = s"($prettyName(${child.sql}))"
}
@ -477,8 +469,6 @@ abstract class BinaryExpression extends Expression {
"""
}
}
override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
@ -499,6 +489,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
def symbol: String
def sqlOperator: String = symbol
override def toString: String = s"($left $symbol $right)"
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
@ -506,17 +498,17 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
override def checkInputDataTypes(): TypeCheckResult = {
// First check whether left and right have the same type, then check if the type is acceptable.
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else if (!inputType.acceptsType(left.dataType)) {
TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," +
TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," +
s" not ${left.dataType.simpleString}")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
override def sql: String = s"(${left.sql} $symbol ${right.sql})"
override def sql: String = s"(${left.sql} $sqlOperator ${right.sql})"
}
@ -623,9 +615,4 @@ abstract class TernaryExpression extends Expression {
"""
}
}
override def sql: String = {
val childrenSQL = children.map(_.sql).mkString(", ")
s"$prettyName($childrenSQL)"
}
}

View file

@ -39,11 +39,11 @@ case class ScalaUDF(
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil)
extends Expression with ImplicitCastInputTypes {
extends Expression with ImplicitCastInputTypes with NonSQLExpression {
override def nullable: Boolean = true
override def toString: String = s"UDF(${children.mkString(",")})"
override def toString: String = s"UDF(${children.mkString(", ")})"
// scalastyle:off line.size.limit

View file

@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
/** The mode of an [[AggregateFunction]]. */
@ -92,8 +91,6 @@ private[sql] case class AggregateExpression(
AttributeSet(childReferences)
}
override def prettyString: String = aggregateFunction.prettyString
override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)"
override def sql: String = aggregateFunction.sql(isDistinct)
@ -168,7 +165,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu
}
def sql(isDistinct: Boolean): String = {
val distinct = if (isDistinct) "DISTINCT " else " "
val distinct = if (isDistinct) "DISTINCT " else ""
s"$prettyName($distinct${children.map(_.sql).mkString(", ")})"
}
}

View file

@ -319,7 +319,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
case class MaxOf(left: Expression, right: Expression)
extends BinaryArithmetic with NonSQLExpression {
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
override def inputType: AbstractDataType = TypeCollection.Ordered
@ -373,7 +375,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "max"
}
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
case class MinOf(left: Expression, right: Expression)
extends BinaryArithmetic with NonSQLExpression {
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
override def inputType: AbstractDataType = TypeCollection.Ordered

View file

@ -123,4 +123,6 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
}
protected override def nullSafeEval(input: Any): Any = not(input)
override def sql: String = s"~${child.sql}"
}

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
import org.apache.spark.sql.catalyst.util.toCommentSafeString
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@ -37,7 +38,7 @@ trait CodegenFallback extends Expression {
val objectTerm = ctx.freshName("obj")
if (nullable) {
s"""
/* expression: ${this.toCommentSafeString} */
/* expression: ${toCommentSafeString(this.toString)} */
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
@ -48,7 +49,7 @@ trait CodegenFallback extends Expression {
} else {
ev.isNull = "false"
s"""
/* expression: ${this.toCommentSafeString} */
/* expression: ${toCommentSafeString(this.toString)} */
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
"""

View file

@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -93,6 +93,8 @@ object ExtractValue {
}
}
trait ExtractValue extends Expression
/**
* Returns the value of fields in the Struct `child`.
*
@ -102,13 +104,15 @@ object ExtractValue {
* For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
*/
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
extends UnaryExpression {
extends UnaryExpression with ExtractValue {
private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType]
override def dataType: DataType = childSchema(ordinal).dataType
override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
override def sql: String =
child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
@ -130,12 +134,11 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
}
})
}
override def sql: String = child.sql + s".`${childSchema(ordinal).name}`"
}
/**
* Returns the array of value of fields in the Array of Struct `child`.
* For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array
* elements, and returns them as a new array.
*
* No need to do type checking since it is handled by [[ExtractValue]].
*/
@ -144,10 +147,11 @@ case class GetArrayStructFields(
field: StructField,
ordinal: Int,
numFields: Int,
containsNull: Boolean) extends UnaryExpression {
containsNull: Boolean) extends UnaryExpression with ExtractValue {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
protected override def nullSafeEval(input: Any): Any = {
val array = input.asInstanceOf[ArrayData]
@ -204,12 +208,13 @@ case class GetArrayStructFields(
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
extends BinaryExpression with ExpectsInputTypes {
extends BinaryExpression with ExpectsInputTypes with ExtractValue {
// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
override def toString: String = s"$child[$ordinal]"
override def sql: String = s"${child.sql}[${ordinal.sql}]"
override def left: Expression = child
override def right: Expression = ordinal
@ -250,7 +255,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
* We need to do type checking here as `key` expression maybe unresolved.
*/
case class GetMapValue(child: Expression, key: Expression)
extends BinaryExpression with ExpectsInputTypes {
extends BinaryExpression with ExpectsInputTypes with ExtractValue {
private def keyType = child.dataType.asInstanceOf[MapType].keyType
@ -258,6 +263,7 @@ case class GetMapValue(child: Expression, key: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"
override def left: Expression = child
override def right: Expression = key

View file

@ -35,7 +35,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
} else if (trueValue.dataType != falseValue.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess

View file

@ -238,37 +238,20 @@ case class Literal protected (value: Any, dataType: DataType)
}
override def sql: String = (value, dataType) match {
case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null =>
"NULL"
case _ if value == null =>
s"CAST(NULL AS ${dataType.sql})"
case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL"
case _ if value == null => s"CAST(NULL AS ${dataType.sql})"
case (v: UTF8String, StringType) =>
// Escapes all backslashes and double quotes.
"\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
case (v: Byte, ByteType) =>
s"CAST($v AS ${ByteType.simpleString.toUpperCase})"
case (v: Short, ShortType) =>
s"CAST($v AS ${ShortType.simpleString.toUpperCase})"
case (v: Long, LongType) =>
s"CAST($v AS ${LongType.simpleString.toUpperCase})"
case (v: Float, FloatType) =>
s"CAST($v AS ${FloatType.simpleString.toUpperCase})"
case (v: Decimal, DecimalType.Fixed(precision, scale)) =>
s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))"
case (v: Int, DateType) =>
s"DATE '${DateTimeUtils.toJavaDate(v)}'"
case (v: Long, TimestampType) =>
s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
case (v: Byte, ByteType) => v + "Y"
case (v: Short, ShortType) => v + "S"
case (v: Long, LongType) => v + "L"
// Float type doesn't have a suffix
case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})"
case (v: Double, DoubleType) => v + "D"
case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})"
case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'"
case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
case _ => value.toString
}
}

View file

@ -42,6 +42,7 @@ abstract class LeafMathExpression(c: Double, name: String)
override def foldable: Boolean = true
override def nullable: Boolean = false
override def toString: String = s"$name()"
override def prettyName: String = name
override def eval(input: InternalRow): Any = c
}
@ -59,6 +60,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
override def dataType: DataType = DoubleType
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
override def prettyName: String = name
protected override def nullSafeEval(input: Any): Any = {
f(input.asInstanceOf[Double])
@ -70,8 +72,6 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
}
override def sql: String = s"$name(${child.sql})"
}
abstract class UnaryLogExpression(f: Double => Double, name: String)
@ -113,6 +113,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
override def toString: String = s"$name($left, $right)"
override def prettyName: String = name
override def dataType: DataType = DoubleType
protected override def nullSafeEval(input1: Any, input2: Any): Any = {

View file

@ -22,6 +22,7 @@ import java.util.UUID
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.types._
object NamedExpression {
@ -183,9 +184,9 @@ case class Alias(child: Expression, name: String)(
override def sql: String = {
val qualifiersString =
if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
s"${child.sql} AS $qualifiersString`$aliasName`"
s"${child.sql} AS $qualifiersString${quoteIdentifier(aliasName)}"
}
}
@ -300,9 +301,9 @@ case class AttributeReference(
override def sql: String = {
val qualifiersString =
if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
s"$qualifiersString`$attrRefName`"
s"$qualifiersString${quoteIdentifier(attrRefName)}"
}
}
@ -310,10 +311,19 @@ case class AttributeReference(
* A place holder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
*/
case class PrettyAttribute(name: String, dataType: DataType = NullType)
case class PrettyAttribute(
name: String,
dataType: DataType = NullType)
extends Attribute with Unevaluable {
def this(attribute: Attribute) = this(attribute.name, attribute match {
case a: AttributeReference => a.dataType
case a: PrettyAttribute => a.dataType
case _ => NullType
})
override def toString: String = name
override def sql: String = toString
override def withNullability(newNullability: Boolean): Attribute =
throw new UnsupportedOperationException

View file

@ -83,8 +83,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
"""
}.mkString("\n")
}
override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}

View file

@ -46,7 +46,7 @@ case class StaticInvoke(
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression {
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
val objectName = staticObject.getName.stripSuffix("$")
@ -108,7 +108,7 @@ case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil) extends Expression {
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
override def children: Seq[Expression] = arguments.+:(targetObject)
@ -204,7 +204,7 @@ case class NewInstance(
arguments: Seq[Expression],
propagateNull: Boolean,
dataType: DataType,
outerPointer: Option[Literal]) extends Expression {
outerPointer: Option[Literal]) extends Expression with NonSQLExpression {
private val className = cls.getName
override def nullable: Boolean = propagateNull
@ -268,7 +268,7 @@ case class NewInstance(
*/
case class UnwrapOption(
dataType: DataType,
child: Expression) extends UnaryExpression with ExpectsInputTypes {
child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
override def nullable: Boolean = true
@ -298,7 +298,7 @@ case class UnwrapOption(
* @param optType The type of this option.
*/
case class WrapOption(child: Expression, optType: DataType)
extends UnaryExpression with ExpectsInputTypes {
extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
override def dataType: DataType = ObjectType(classOf[Option[_]])
@ -328,7 +328,7 @@ case class WrapOption(child: Expression, optType: DataType)
* manually, but will instead be passed into the provided lambda function.
*/
case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
with Unevaluable {
with Unevaluable with NonSQLExpression {
override def nullable: Boolean = true
@ -368,7 +368,7 @@ object MapObjects {
case class MapObjects private(
loopVar: LambdaVariable,
lambdaFunction: Expression,
inputData: Expression) extends Expression {
inputData: Expression) extends Expression with NonSQLExpression {
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
case NullType =>
@ -483,7 +483,7 @@ case class MapObjects private(
*
* @param children A list of expression to use as content of the external row.
*/
case class CreateExternalRow(children: Seq[Expression]) extends Expression {
case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression {
override def dataType: DataType = ObjectType(classOf[Row])
override def nullable: Boolean = false
@ -516,7 +516,8 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
* Serializes an input object using a generic serializer (Kryo or Java).
* @param kryo if true, use Kryo. Otherwise, use Java.
*/
case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression {
case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
extends UnaryExpression with NonSQLExpression {
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
@ -558,7 +559,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary
* @param kryo if true, use Kryo. Otherwise, use Java.
*/
case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
extends UnaryExpression {
extends UnaryExpression with NonSQLExpression {
override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
// Code to initialize the serializer.
@ -596,7 +597,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
* Initialize a Java Bean instance by setting its field values via setters.
*/
case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
extends Expression {
extends Expression with NonSQLExpression {
override def nullable: Boolean = beanInstance.nullable
override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
@ -638,7 +639,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
* non-null `s`, `s.i` can't be null.
*/
case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
extends UnaryExpression {
extends UnaryExpression with NonSQLExpression {
override def dataType: DataType = child.dataType

View file

@ -249,6 +249,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
override def symbol: String = "&&"
override def sqlOperator: String = "AND"
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == false) {
@ -289,8 +291,6 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
}
"""
}
override def sql: String = s"(${left.sql} AND ${right.sql})"
}
@ -300,6 +300,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
override def symbol: String = "||"
override def sqlOperator: String = "OR"
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == true) {
@ -340,8 +342,6 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
}
"""
}
override def sql: String = s"(${left.sql} OR ${right.sql})"
}

View file

@ -23,7 +23,6 @@ import java.util.{HashMap, Locale, Map => JMap}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@ -62,8 +61,6 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}
override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
@ -156,8 +153,6 @@ case class ConcatWs(children: Seq[Expression])
"""
}
}
override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
trait String2StringExpression extends ImplicitCastInputTypes {

View file

@ -556,8 +556,8 @@ abstract class RankLike extends AggregateWindowFunction {
override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
/** Store the values of the window 'order' expressions. */
protected val orderAttrs = children.map{ expr =>
AttributeReference(expr.prettyString, expr.dataType)()
protected val orderAttrs = children.map { expr =>
AttributeReference(expr.sql, expr.dataType)()
}
/** Predicate that detects if the order attributes have changed. */
@ -636,7 +636,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
/**
* The PercentRank function computes the percentage ranking of a value in a group of values. The
* result the rank of the minus one divided by the total number of rows in the partitiion minus one:
* result the rank of the minus one divided by the total number of rows in the partition minus one:
* (r - 1) / (n - 1). If a partition only contains one row, the function will return 0.
*
* The PercentRank function is similar to the CumeDist function, but it uses rank values instead of

View file

@ -575,7 +575,7 @@ case class Pivot(
override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
case _ => pivotValues.flatMap{ value =>
aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
}
}
}

View file

@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst
import java.io._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{NumericType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
package object util {
@ -130,20 +133,32 @@ package object util {
ret
}
/**
* Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`.
*/
def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
case xs if xs.isEmpty =>
Option(Seq.empty[T])
case xs =>
for {
head <- xs.head
tail <- sequenceOption(xs.tail)
} yield head +: tail
// Replaces attributes, string literals, complex type extractors with their pretty form so that
// generated column names don't contain back-ticks or double-quotes.
def usePrettyExpression(e: Expression): Expression = e transform {
case a: Attribute => new PrettyAttribute(a)
case Literal(s: UTF8String, StringType) => PrettyAttribute(s.toString, StringType)
case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t)
case e: GetStructField =>
val name = e.name.getOrElse(e.childSchema(e.ordinal).name)
PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType)
case e: GetArrayStructFields =>
PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType)
}
def quoteIdentifier(name: String): String = {
// Escapes back-ticks within the identifier name with double-back-ticks, and then quote the
// identifier with back-ticks.
"`" + name.replace("`", "``") + "`"
}
/**
* Returns the string representation of this expression that is safe to be put in
* code comments of generated code.
*/
def toCommentSafeString(str: String): String =
str.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
/* FIX ME
implicit class debugLogging(a: Any) {
def debugLogging() {

View file

@ -64,6 +64,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
override def toString: String = s"DecimalType($precision,$scale)"
override def sql: String = typeName.toUpperCase
/**
* Returns whether this DecimalType is wider than `other`. If yes, it means `other`
* can be casted into `this` safely without losing any precision or range.

View file

@ -25,7 +25,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, DataTypeParser, LegacyTypeStringParser}
/**
* :: DeveloperApi ::
@ -280,7 +280,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
}
override def sql: String = {
val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}")
val fieldTypes = fields.map(f => s"${quoteIdentifier(f.name)}: ${f.dataType.sql}")
s"STRUCT<${fieldTypes.mkString(", ")}>"
}

View file

@ -127,22 +127,24 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"single invalid type, single arg",
testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
"cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" ::
"'null' is of date type" :: Nil)
"cannot resolve" :: "testfunction(CAST(NULL AS DATE))" :: "argument 1" :: "requires int type" ::
"'CAST(NULL AS DATE)' is of date type" :: Nil)
errorTest(
"single invalid type, second arg",
testRelation.select(
TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
"cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" ::
"'null' is of date type" :: Nil)
"cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" ::
"argument 2" :: "requires int type" ::
"'CAST(NULL AS DATE)' is of date type" :: Nil)
errorTest(
"multiple invalid type",
testRelation.select(
TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
"cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
"requires int type" :: "'null' is of date type" :: Nil)
"cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" ::
"argument 1" :: "argument 2" :: "requires int type" ::
"'CAST(NULL AS DATE)' is of date type" :: Nil)
errorTest(
"invalid window function",
@ -207,7 +209,7 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"sorting by attributes are not from grouping expressions",
testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc),
"cannot resolve" :: "'b'" :: "given input columns" :: "[a, c, a3]" :: Nil)
"cannot resolve" :: "'`b`'" :: "given input columns" :: "[a, c, a3]" :: Nil)
errorTest(
"non-boolean filters",
@ -222,7 +224,7 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"missing group by",
testRelation2.groupBy('a)('b),
"'b'" :: "group by" :: Nil
"'`b`'" :: "group by" :: Nil
)
errorTest(
@ -270,7 +272,7 @@ class AnalysisErrorSuite extends AnalysisTest {
"SPARK-9955: correct error message for aggregate",
// When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias.
testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),
"cannot resolve 'bad_column'" :: Nil)
"cannot resolve '`bad_column`'" :: Nil)
test("SPARK-6452 regression test") {
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
@ -311,7 +313,7 @@ class AnalysisErrorSuite extends AnalysisTest {
case true =>
assertAnalysisSuccess(plan, true)
case false =>
assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
assertAnalysisError(plan, "expression `a` cannot be used as a grouping expression" :: Nil)
}
}
@ -372,7 +374,7 @@ class AnalysisErrorSuite extends AnalysisTest {
Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil)
assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)
val plan2 =
Join(
@ -386,6 +388,6 @@ class AnalysisErrorSuite extends AnalysisTest {
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil)
assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
}
}

View file

@ -71,8 +71,17 @@ trait AnalysisTest extends PlanTest {
val e = intercept[AnalysisException] {
analyzer.checkAnalysis(analyzer.execute(inputPlan))
}
assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains),
s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " +
s"actually we get ${e.getMessage}")
if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) {
fail(
s"""Exception message should contain the following substrings:
|
| ${expectedErrors.mkString("\n ")}
|
|Actual exception message:
|
| ${e.getMessage}
""".stripMargin)
}
}
}

View file

@ -41,7 +41,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(expr)
}
assert(e.getMessage.contains(
s"cannot resolve '${expr.prettyString}' due to data type mismatch:"))
s"cannot resolve '${expr.sql}' due to data type mismatch:"))
assert(e.getMessage.contains(errorMessage))
}
@ -52,7 +52,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
def assertErrorForDifferingTypes(expr: Expression): Unit = {
assertError(expr,
s"differing types in '${expr.prettyString}'")
s"differing types in '${expr.sql}'")
}
test("check types for unary arithmetic") {

View file

@ -142,7 +142,7 @@ class EncoderResolutionSuite extends PlanTest {
}.message
assert(msg2 ==
s"""
|Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate
|Cannot up cast `b`.`b` from decimal(38,18) to bigint as it may truncate
|The type path of the target object is:
|- field (class: "scala.Long", name: "b")
|- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b")

View file

@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DataTypeParser}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
@ -104,10 +104,9 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def this(name: String) = this(name match {
case "*" => UnresolvedStar(None)
case _ if name.endsWith(".*") => {
case _ if name.endsWith(".*") =>
val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2))
UnresolvedStar(Some(parts))
}
case _ => UnresolvedAttribute.quotedString(name)
})
@ -123,6 +122,8 @@ class Column(protected[sql] val expr: Expression) extends Logging {
// make it a NamedExpression.
case u: UnresolvedAttribute => UnresolvedAlias(u)
case u: UnresolvedExtractValue => UnresolvedAlias(u)
case expr: NamedExpression => expr
// Leave an unaliased generator with an empty list of names since the analyzer will generate
@ -131,7 +132,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case jt: JsonTuple => MultiAlias(jt, Nil)
case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString))
case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql))
// If we have a top level Cast, there is a chance to give it a better alias, if there is a
// NamedExpression under this Cast.
@ -139,13 +140,13 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
} match {
case ne: NamedExpression => ne
case other => Alias(expr, expr.prettyString)()
case other => Alias(expr, usePrettyExpression(expr).sql)()
}
case expr: Expression => Alias(expr, expr.prettyString)()
case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
override def toString: String = expr.prettyString
override def toString: String = usePrettyExpression(expr).sql
override def equals(that: Any): Boolean = that match {
case that: Column => that.expr.equals(this.expr)
@ -987,7 +988,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
if (extended) {
println(expr)
} else {
println(expr.prettyString)
println(expr.sql)
}
// scalastyle:on println
}

View file

@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
@ -1359,7 +1360,8 @@ class DataFrame private[sql](
"min" -> ((child: Expression) => Min(child).toAggregateExpression()),
"max" -> ((child: Expression) => Max(child).toAggregateExpression()))
val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
val outputCols =
(if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList
val ret: Seq[Row] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>

View file

@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.types.NumericType
/**
@ -74,7 +75,7 @@ class GroupedData protected[sql](
private[this] def alias(expr: Expression): NamedExpression = expr match {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)

View file

@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
@ -252,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
/** Codegened pipeline for:
* ${plan.treeString.trim}
* ${toCommentSafeString(plan.treeString.trim)}
*/
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
@ -324,7 +324,7 @@ private[sql] case class ScalaUDAF(
udaf: UserDefinedAggregateFunction,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with Logging {
extends ImperativeAggregate with NonSQLExpression with Logging {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

View file

@ -544,7 +544,7 @@ private[parquet] object CatalystSchemaConverter {
!name.matches(".*[ ,;{}()\n\t=].*"),
s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
|Please use alias to rename it.
""".stripMargin.split("\n").mkString(" "))
""".stripMargin.split("\n").mkString(" ").trim)
}
def checkFieldNames(schema: StructType): StructType = {

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python
import org.apache.spark.{Accumulator, Logging}
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.types.DataType
/**
@ -36,9 +36,12 @@ case class PythonUDF(
broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
accumulator: Accumulator[java.util.List[Array[Byte]]],
dataType: DataType,
children: Seq[Expression]) extends Expression with Unevaluable with Logging {
children: Seq[Expression])
extends Expression with Unevaluable with NonSQLExpression with Logging {
override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
override def toString: String = s"PythonUDF#$name(${children.mkString(", ")})"
override def nullable: Boolean = true
override def sql: String = s"$name(${children.mkString(", ")})"
}

View file

@ -525,7 +525,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException] {
ds.as[ClassData2]
}
assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
assert(e.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"), e.getMessage)
}
test("runtime nullability check") {

View file

@ -455,7 +455,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)
@ -479,7 +479,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)

View file

@ -24,10 +24,11 @@ import scala.util.control.NonFatal
import org.apache.spark.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
/**
@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
* supported by this builder (yet).
*/
class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
require(logicalPlan.resolved, "SQLBuilder only supports resloved logical query plans")
def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
def toSQL: String = {
val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
try {
canonicalizedPlan.transformAllExpressions {
case e: NonSQLExpression =>
throw new UnsupportedOperationException(
s"Expression $e doesn't have a SQL representation"
)
case e => e
}
val generatedSQL = toSQL(canonicalizedPlan)
logDebug(
s"""Built SQL query string successfully from given logical plan:
@ -95,7 +106,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
p.child match {
// Persisted data source relation
case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
s"`$database`.`$table`"
s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
// Parentheses is not used for persisted data source relations
// e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
@ -114,8 +125,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: MetastoreRelation =>
build(
s"`${p.databaseName}`.`${p.tableName}`",
p.alias.map(a => s" AS `$a`").getOrElse("")
s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}",
p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("")
)
case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
@ -148,7 +159,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
* The segments are trimmed so only a single space appears in the separation.
* For example, `build("a", " b ", " c")` becomes "a b c".
*/
private def build(segments: String*): String = segments.map(_.trim).mkString(" ")
private def build(segments: String*): String =
segments.map(_.trim).filter(_.nonEmpty).mkString(" ")
private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
build(

View file

@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.types._
@ -70,18 +69,28 @@ private[hive] class HiveFunctionRegistry(
// catch the exception and throw AnalysisException instead.
try {
if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUDF(
val udf = HiveGenericUDF(
name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
udf.dataType // Force it to check input data types.
udf
} else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
udf.dataType // Force it to check input data types.
udf
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
udf.dataType // Force it to check input data types.
udf
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
udaf.dataType // Force it to check input data types.
udaf
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUDAFFunction(
val udaf = HiveUDAFFunction(
name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
udaf.dataType // Force it to check input data types.
udaf
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
udtf.elementTypes // Force it to check input data types.
@ -163,7 +172,7 @@ private[hive] case class HiveSimpleUDF(
@transient
private lazy val conversionHelper = new ConversionHelper(method, arguments)
override val dataType = javaClassToDataType(method.getReturnType)
override lazy val dataType = javaClassToDataType(method.getReturnType)
@transient
lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
@ -189,6 +198,8 @@ private[hive] case class HiveSimpleUDF(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
override def prettyName: String = name
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
@ -233,11 +244,11 @@ private[hive] case class HiveGenericUDF(
}
@transient
private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
new DeferredObjectAdapter(inspect, child.dataType)
}.toArray[DeferredObject]
override val dataType: DataType = inspectorToDataType(returnInspector)
override lazy val dataType: DataType = inspectorToDataType(returnInspector)
override def eval(input: InternalRow): Any = {
returnInspector // Make sure initialized.
@ -245,20 +256,20 @@ private[hive] case class HiveGenericUDF(
var i = 0
while (i < children.length) {
val idx = i
deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(
() => {
children(idx).eval(input)
})
i += 1
}
unwrap(function.evaluate(deferedObjects), returnInspector)
unwrap(function.evaluate(deferredObjects), returnInspector)
}
override def prettyName: String = name
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
/**
@ -340,7 +351,7 @@ private[hive] case class HiveGenericUDTF(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
override def prettyName: String = name
}
/**
@ -432,7 +443,9 @@ private[hive] case class HiveUDAFFunction(
override def supportsPartial: Boolean = false
override val dataType: DataType = inspectorToDataType(returnInspector)
override lazy val dataType: DataType = inspectorToDataType(returnInspector)
override def prettyName: String = name
override def sql(isDistinct: Boolean): String = {
val distinct = if (isDistinct) "DISTINCT " else " "

View file

@ -26,17 +26,24 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
test("literal") {
checkSQL(Literal("foo"), "\"foo\"")
checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
checkSQL(Literal(1: Byte), "1Y")
checkSQL(Literal(2: Short), "2S")
checkSQL(Literal(4: Int), "4")
checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
checkSQL(Literal(8: Long), "8L")
checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
checkSQL(Literal(2.5D), "2.5")
checkSQL(Literal(2.5D), "2.5D")
checkSQL(
Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
// TODO tests for decimals
}
test("attributes") {
checkSQL('a.int, "`a`")
checkSQL(Symbol("foo bar").int, "`foo bar`")
// Keyword
checkSQL('int.int, "`int`")
}
test("binary comparisons") {
checkSQL('a.int === 'b.int, "(`a` = `b`)")
checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")

View file

@ -29,6 +29,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
sql("DROP TABLE IF EXISTS t0")
sql("DROP TABLE IF EXISTS t1")
sql("DROP TABLE IF EXISTS t2")
sqlContext.range(10).write.saveAsTable("t0")
sqlContext
@ -168,4 +169,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
}
}
test("plans with non-SQL expressions") {
sqlContext.udf.register("foo", (_: Int) * 2)
intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL)
}
test("named expression in column names shouldn't be quoted") {
def checkColumnNames(query: String, expectedColNames: String*): Unit = {
checkHiveQl(query)
assert(sql(query).columns === expectedColNames)
}
// Attributes
checkColumnNames(
"""SELECT * FROM (
| SELECT 1 AS a, 2 AS b, 3 AS `we``ird`
|) s
""".stripMargin,
"a", "b", "we`ird"
)
checkColumnNames(
"""SELECT x.a, y.a, x.b, y.b
|FROM (SELECT 1 AS a, 2 AS b) x
|INNER JOIN (SELECT 1 AS a, 2 AS b) y
|ON x.a = y.a
""".stripMargin,
"a", "a", "b", "b"
)
// String literal
checkColumnNames(
"SELECT 'foo', '\"bar\\''",
"foo", "\"bar\'"
)
// Numeric literals (should have CAST or suffixes in column names)
checkColumnNames(
"SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D",
"1", "2", "3", "4", "5.1", "6.1"
)
// Aliases
checkColumnNames(
"SELECT 1 AS a",
"a"
)
// Complex type extractors
checkColumnNames(
"""SELECT
| a.f1, b[0].f1, b.f1, c["foo"], d[0]
|FROM (
| SELECT
| NAMED_STRUCT("f1", 1, "f2", "foo") AS a,
| ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b,
| MAP("foo", 1) AS c,
| ARRAY(1) AS d
|) s
""".stripMargin,
"f1", "b[0].f1", "f1", "c[foo]", "d[0]"
)
}
}

View file

@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton {
test("udf constant folding") {
Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t")
val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan
val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan
val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan
val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan
comparePlans(optimized, correctAnswer)
}

View file

@ -131,17 +131,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
val df = sql(
"""
|SELECT
| CAST(null as TINYINT),
| CAST(null as SMALLINT),
| CAST(null as INT),
| CAST(null as BIGINT),
| CAST(null as FLOAT),
| CAST(null as DOUBLE),
| CAST(null as DECIMAL(7,2)),
| CAST(null as TIMESTAMP),
| CAST(null as DATE),
| CAST(null as STRING),
| CAST(null as VARCHAR(10))
| CAST(null as TINYINT) as c0,
| CAST(null as SMALLINT) as c1,
| CAST(null as INT) as c2,
| CAST(null as BIGINT) as c3,
| CAST(null as FLOAT) as c4,
| CAST(null as DOUBLE) as c5,
| CAST(null as DECIMAL(7,2)) as c6,
| CAST(null as TIMESTAMP) as c7,
| CAST(null as DATE) as c8,
| CAST(null as STRING) as c9,
| CAST(null as VARCHAR(10)) as c10
|FROM orc_temp_table limit 1
""".stripMargin)

View file

@ -626,7 +626,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
sql(
s"""CREATE TABLE array_of_struct
|STORED AS PARQUET LOCATION '$path'
|AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b'))
|AS SELECT
| '1st' AS a,
| '2nd' AS b,
| ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c
""".stripMargin)
checkAnswer(