[SPARK-12799] Simplify various string output for expressions
This PR introduces several major changes:

1. Replacing `Expression.prettyString` with `Expression.sql`

   The `prettyString` method is mostly an internal, developer-facing facility for debugging purposes, and shouldn't be exposed to users.

2. Using SQL-like representation as column names for selected fields that are not named expressions (back-ticks and double quotes should be removed)

   Before, we were using `prettyString` as column names when possible, and sometimes the resulting column names can be weird. Here are several examples:

   | Expression        | `prettyString` | `sql`     | Note            |
   | ----------------- | -------------- | --------- | --------------- |
   | `a && b`          | `a && b`       | `a AND b` |                 |
   | `a.getField("f")` | `a[f]`         | `a.f`     | `a` is a struct |

3. Adding trait `NonSQLExpression` extending from `Expression` for expressions that don't have a SQL representation (e.g. Scala UDF/UDAF and Java/Scala object expressions used for encoders)

   `NonSQLExpression.sql` may return an arbitrary user-facing string representation of the expression.

Author: Cheng Lian <lian@databricks.com>

Closes #10757 from liancheng/spark-12799.simplify-expression-string-methods.
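To make the renaming concrete, here is a hypothetical spark-shell session (not part of the patch; it simply mirrors the PySpark doctest updates in the diff below):

```scala
// Hypothetical spark-shell session against a build that includes this patch.
// `between` desugars to a conjunction, so the generated column name changes
// from the old prettyString form to the new SQL-like form.
val df = sqlContext.createDataFrame(Seq(("Alice", 2), ("Bob", 5))).toDF("name", "age")

df.select(df("age").between(2, 4)).columns.head
// before: ((age >= 2) && (age <= 4))
// after:  ((age >= 2) AND (age <= 4))
```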
parent d806ed3436
commit d9efe63ecd
@@ -1047,13 +1047,13 @@ test_that("column functions", {
                     schema = c("a", "b", "c"))
   result <- collect(select(df, struct("a", "c")))
   expected <- data.frame(row.names = 1:2)
-  expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)),
+  expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)),
                                  listToStruct(list(a = 4L, c = 6L)))
   expect_equal(result, expected)

   result <- collect(select(df, struct(df$a, df$b)))
   expected <- data.frame(row.names = 1:2)
-  expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)),
+  expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)),
                                  listToStruct(list(a = 4L, b = 5L)))
   expect_equal(result, expected)
@@ -219,17 +219,17 @@ class Column(object):
        >>> from pyspark.sql import Row
        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
        >>> df.select(df.r.getField("b")).show()
-        +----+
-        |r[b]|
-        +----+
-        |   b|
-        +----+
+        +---+
+        |r.b|
+        +---+
+        |  b|
+        +---+
        >>> df.select(df.r.a).show()
-        +----+
-        |r[a]|
-        +----+
-        |   1|
-        +----+
+        +---+
+        |r.a|
+        +---+
+        |  1|
+        +---+
        """
        return self[name]

@@ -346,12 +346,12 @@ class Column(object):
        expression is between the given columns.

        >>> df.select(df.name, df.age.between(2, 4)).show()
-        +-----+--------------------------+
-        | name|((age >= 2) && (age <= 4))|
-        +-----+--------------------------+
-        |Alice|                      true|
-        |  Bob|                     false|
-        +-----+--------------------------+
+        +-----+---------------------------+
+        | name|((age >= 2) AND (age <= 4))|
+        +-----+---------------------------+
+        |Alice|                       true|
+        |  Bob|                      false|
+        +-----+---------------------------+
        """
        return (self >= lowerBound) & (self <= upperBound)
@@ -92,8 +92,8 @@ class SQLContext(object):
        >>> df.registerTempTable("allTypes")
        >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
        ...                'from allTypes where b and i > 0').collect()
-        [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
-            time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+        [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
+            dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
        >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
        [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
        """

@@ -210,17 +210,17 @@ class SQLContext(object):

        >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
        >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
-        [Row(_c0=u'4')]
+        [Row(stringLengthString(test)=u'4')]

        >>> from pyspark.sql.types import IntegerType
        >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(_c0=4)]
+        [Row(stringLengthInt(test)=4)]

        >>> from pyspark.sql.types import IntegerType
        >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(_c0=4)]
+        [Row(stringLengthInt(test)=4)]
        """
        udf = UserDefinedFunction(f, returnType, name)
        self._ssql_ctx.udf().registerPython(name, udf._judf)
@@ -223,22 +223,22 @@ def coalesce(*cols):
    +----+----+

    >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
-    +-------------+
-    |coalesce(a,b)|
-    +-------------+
-    |         null|
-    |            1|
-    |            2|
-    +-------------+
+    +--------------+
+    |coalesce(a, b)|
+    +--------------+
+    |          null|
+    |             1|
+    |             2|
+    +--------------+

    >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
-    +----+----+---------------+
-    |   a|   b|coalesce(a,0.0)|
-    +----+----+---------------+
-    |null|null|            0.0|
-    |   1|null|            1.0|
-    |null|   2|            0.0|
-    +----+----+---------------+
+    +----+----+----------------+
+    |   a|   b|coalesce(a, 0.0)|
+    +----+----+----------------+
+    |null|null|             0.0|
+    |   1|null|             1.0|
+    |null|   2|             0.0|
+    +----+----+----------------+
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column))

@@ -1528,7 +1528,7 @@ def array_contains(col, value):

    >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
    >>> df.select(array_contains(df.data, "a")).collect()
-    [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)]
+    [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.types._

 /**

@@ -165,7 +166,8 @@ class Analyzer(
          case e if !e.resolved => u
          case g: Generator => MultiAlias(g, Nil)
          case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
-         case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))()
+         case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)()
+         case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))()
        }
      }
    }.asInstanceOf[Seq[NamedExpression]]

@@ -328,7 +330,7 @@ class Analyzer(
          throw new AnalysisException(
            s"Aggregate expression required for pivot, found '$aggregate'")
        }
-       val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
+       val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
        Alias(filteredAggregate, name)()
      }
    }

@@ -1456,7 +1458,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
  */
 object ResolveUpCast extends Rule[LogicalPlan] {
   private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
-    throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
+    throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
      s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
      "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
      "You can either add an explicit cast to the input data or choose a higher precision " +
@@ -57,13 +57,13 @@ trait CheckAnalysis {
        operator transformExpressionsUp {
          case a: Attribute if !a.resolved =>
            val from = operator.inputSet.map(_.name).mkString(", ")
-           a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]")
+           a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")

          case e: Expression if e.checkInputDataTypes().isFailure =>
            e.checkInputDataTypes() match {
              case TypeCheckResult.TypeCheckFailure(message) =>
                e.failAnalysis(
-                 s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
+                 s"cannot resolve '${e.sql}' due to data type mismatch: $message")
            }

          case c: Cast if !c.resolved =>

@@ -106,12 +106,12 @@ trait CheckAnalysis {
        operator match {
          case f: Filter if f.condition.dataType != BooleanType =>
            failAnalysis(
-             s"filter expression '${f.condition.prettyString}' " +
+             s"filter expression '${f.condition.sql}' " +
                s"of type ${f.condition.dataType.simpleString} is not a boolean.")

          case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
            failAnalysis(
-             s"join condition '${condition.prettyString}' " +
+             s"join condition '${condition.sql}' " +
                s"of type ${condition.dataType.simpleString} is not a boolean.")

          case j @ Join(_, _, _, Some(condition)) =>

@@ -119,10 +119,10 @@ trait CheckAnalysis {
              case p: Predicate =>
                p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
              case e if e.dataType.isInstanceOf[BinaryType] =>
-               failAnalysis(s"binary type expression ${e.prettyString} cannot be used " +
+               failAnalysis(s"binary type expression ${e.sql} cannot be used " +
                  "in join conditions")
              case e if e.dataType.isInstanceOf[MapType] =>
-               failAnalysis(s"map type expression ${e.prettyString} cannot be used " +
+               failAnalysis(s"map type expression ${e.sql} cannot be used " +
                  "in join conditions")
              case _ => // OK
            }

@@ -144,13 +144,13 @@ trait CheckAnalysis {

                if (!child.deterministic) {
                  failAnalysis(
-                   s"nondeterministic expression ${expr.prettyString} should not " +
+                   s"nondeterministic expression ${expr.sql} should not " +
                      s"appear in the arguments of an aggregate function.")
                }
              }
            case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
              failAnalysis(
-               s"expression '${e.prettyString}' is neither present in the group by, " +
+               s"expression '${e.sql}' is neither present in the group by, " +
                  s"nor is it an aggregate function. " +
                  "Add to group by or wrap in first() (or first_value) if you don't care " +
                  "which value you get.")

@@ -163,7 +163,7 @@ trait CheckAnalysis {
            // Check if the data type of expr is orderable.
            if (!RowOrdering.isOrderable(expr.dataType)) {
              failAnalysis(
-               s"expression ${expr.prettyString} cannot be used as a grouping expression " +
+               s"expression ${expr.sql} cannot be used as a grouping expression " +
                  s"because its data type ${expr.dataType.simpleString} is not a orderable " +
                  s"data type.")
            }

@@ -172,7 +172,7 @@ trait CheckAnalysis {
              // This is just a sanity check, our analysis rule PullOutNondeterministic should
              // already pull out those nondeterministic expressions and evaluate them in
              // a Project node.
-             failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " +
+             failAnalysis(s"nondeterministic expression ${expr.sql} should not " +
                s"appear in grouping expression.")
            }
          }

@@ -217,7 +217,7 @@ trait CheckAnalysis {
          case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
            failAnalysis(
              s"""Only a single table generating function is allowed in a SELECT clause, found:
-                | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
+                | ${exprs.map(_.sql).mkString(",")}""".stripMargin)

          case j: Join if !j.duplicateResolved =>
            val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)

@@ -248,9 +248,9 @@ trait CheckAnalysis {
            failAnalysis(
              s"""nondeterministic expressions are only allowed in
                 |Project, Filter, Aggregate or Window, found:
-                | ${o.expressions.map(_.prettyString).mkString(",")}
+                | ${o.expressions.map(_.sql).mkString(",")}
                 |in operator ${operator.simpleString}
-                """.stripMargin)
+             """.stripMargin)

          case _ => // Analysis successful!
        }
@@ -130,7 +130,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
      new AttributeReference("gid", IntegerType, false)(isGenerated = true)
    val groupByMap = a.groupingExpressions.collect {
      case ne: NamedExpression => ne -> ne.toAttribute
-     case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
+     case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
    }
    val groupByAttrs = groupByMap.map(_._2)

@@ -184,7 +184,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
    val regularAggOperatorMap = regularAggExprs.map { e =>
      // Perform the actual aggregation in the initial aggregate.
      val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
-     val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
+     val operator = Alias(e.copy(aggregateFunction = af), e.sql)()

      // Select the result of the first aggregate in the last aggregate.
      val result = AggregateExpression(

@@ -269,5 +269,5 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
    // children, in this case attribute reuse causes the input of the regular aggregate to bound to
    // the (nulled out) input of the distinct aggregate.
-   e -> new AttributeReference(e.prettyString, e.dataType, true)()
+   e -> new AttributeReference(e.sql, e.dataType, true)()
  }
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.types.{DataType, StructType}

 /**

@@ -67,6 +68,8 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
   override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)

   override def toString: String = s"'$name"
+
+  override def sql: String = quoteIdentifier(name)
 }

 object UnresolvedAttribute {

@@ -141,11 +144,8 @@ case class UnresolvedFunction(
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false

-  override def prettyString: String = {
-    s"${name}(${children.map(_.prettyString).mkString(",")})"
-  }
-
-  override def toString: String = s"'$name(${children.mkString(",")})"
+  override def prettyName: String = name
+  override def toString: String = s"'$name(${children.mkString(", ")})"
 }

 /**

@@ -208,10 +208,9 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
            Alias(extract, f.name)()
          }

-       case _ => {
+       case _ =>
          throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
            target.get + "`")
-       }
      }
    } else {
      val from = input.inputSet.map(_.name).mkString(", ")

@@ -228,6 +227,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
 * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented
 * as follows:
+ * MultiAlias(stack_function, Seq(a, b))
 *
 * @param child the computation being performed
 * @param names the names to be associated with each output of computing [[child]].

@@ -284,13 +284,14 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
   override lazy val resolved = false

   override def toString: String = s"$child[$extraction]"
+  override def sql: String = s"${child.sql}[${extraction.sql}]"
 }

 /**
  * Holds the expression that has yet to be aliased.
  *
  * @param child The computation that is needs to be resolved during analysis.
- * @param aliasName The name if specified to be asoosicated with the result of computing [[child]]
+ * @param aliasName The name if specified to be associated with the result of computing [[child]]
  *
  */
 case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
@@ -45,7 +45,7 @@ trait ExpectsInputTypes extends Expression {
    val mismatches = children.zip(inputTypes).zipWithIndex.collect {
      case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
        s"argument ${idx + 1} requires ${expected.simpleString} type, " +
-         s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type."
+         s"however, '${child.sql}' is of ${child.dataType.simpleString} type."
    }

    if (mismatches.isEmpty) {
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.util.sequenceOption
+import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.types._

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -96,7 +96,7 @@ abstract class Expression extends TreeNode[Expression] {
    ctx.subExprEliminationExprs.get(this).map { subExprState =>
      // This expression is repeated meaning the code to evaluated has already been added
      // as a function and called in advance. Just use it.
-     val code = s"/* ${this.toCommentSafeString} */"
+     val code = s"/* ${toCommentSafeString(this.toString)} */"
      ExprCode(code, subExprState.isNull, subExprState.value)
    }.getOrElse {
      val isNull = ctx.freshName("isNull")

@@ -105,7 +105,7 @@ abstract class Expression extends TreeNode[Expression] {
      ve.code = genCode(ctx, ve)
      if (ve.code != "") {
        // Add `this` in the comment.
-       ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
+       ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim)
      } else {
        ve
      }

@@ -201,17 +201,6 @@ abstract class Expression extends TreeNode[Expression] {
   */
  def prettyName: String = getClass.getSimpleName.toLowerCase

- /**
-  * Returns a user-facing string representation of this expression, i.e. does not have developer
-  * centric debugging information like the expression id.
-  */
- def prettyString: String = {
-   transform {
-     case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
-     case u: UnresolvedAttribute => PrettyAttribute(u.name)
-   }.toString
- }
-
 private def flatArguments = productIterator.flatMap {
   case t: Traversable[_] => t
   case single => single :: Nil

@@ -219,24 +208,16 @@ abstract class Expression extends TreeNode[Expression] {

  override def simpleString: String = toString

- override def toString: String = prettyName + flatArguments.mkString("(", ",", ")")
+ override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")")

  /**
-  * Returns the string representation of this expression that is safe to be put in
-  * code comments of generated code.
+  * Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]],
+  * this method may return an arbitrary user facing string.
   */
- protected def toCommentSafeString: String = this.toString
-   .replace("*/", "\\*\\/")
-   .replace("\\u", "\\\\u")
-
- /**
-  * Returns SQL representation of this expression. For expressions that don't have a SQL
-  * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`.
-  */
- @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation")
- def sql: String = throw new UnsupportedOperationException(
-   s"Cannot map expression $this to its SQL representation"
- )
+ def sql: String = {
+   val childrenSQL = children.map(_.sql).mkString(", ")
+   s"$prettyName($childrenSQL)"
+ }
 }
@@ -254,6 +235,19 @@ trait Unevaluable extends Expression {
 }


+/**
+ * Expressions that don't have SQL representation should extend this trait. Examples are
+ * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
+ */
+trait NonSQLExpression extends Expression {
+  override def sql: String = {
+    transform {
+      case a: Attribute => new PrettyAttribute(a)
+    }.toString
+  }
+}
+
+
 /**
  * An expression that is nondeterministic.
  */
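As a rough sketch of how the two `sql` paths behave (assuming a REPL with this patch applied; the expected strings follow from `BinaryOperator.sql`, `AttributeReference.sql`, and `ScalaUDF.toString` as defined elsewhere in this diff):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// BinaryOperator.sql composes the children with `sqlOperator`:
And(GreaterThan(a, Literal(0)), LessThan(b, Literal(2))).sql
// expected: ((`a` > 0) AND (`b` < 2))

// ScalaUDF mixes in NonSQLExpression, so `sql` falls back to the
// toString-based form with attributes prettified (no back-ticks):
ScalaUDF((i: Int) => i + 1, IntegerType, Seq(a)).sql
// expected: UDF(a)
```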
@@ -373,8 +367,6 @@ abstract class UnaryExpression extends Expression {
    """
    }
  }
-
- override def sql: String = s"($prettyName(${child.sql}))"
 }

@@ -477,8 +469,6 @@ abstract class BinaryExpression extends Expression {
    """
    }
  }
-
- override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
 }

@@ -499,6 +489,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {

  def symbol: String

+ def sqlOperator: String = symbol
+
  override def toString: String = s"($left $symbol $right)"

  override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)

@@ -506,17 +498,17 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
  override def checkInputDataTypes(): TypeCheckResult = {
    // First check whether left and right have the same type, then check if the type is acceptable.
    if (left.dataType != right.dataType) {
-     TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+     TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
        s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
    } else if (!inputType.acceptsType(left.dataType)) {
-     TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," +
+     TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," +
        s" not ${left.dataType.simpleString}")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

- override def sql: String = s"(${left.sql} $symbol ${right.sql})"
+ override def sql: String = s"(${left.sql} $sqlOperator ${right.sql})"
 }

@@ -623,9 +615,4 @@ abstract class TernaryExpression extends Expression {
    """
    }
  }
-
- override def sql: String = {
-   val childrenSQL = children.map(_.sql).mkString(", ")
-   s"$prettyName($childrenSQL)"
- }
 }
@@ -39,11 +39,11 @@ case class ScalaUDF(
    dataType: DataType,
    children: Seq[Expression],
    inputTypes: Seq[DataType] = Nil)
- extends Expression with ImplicitCastInputTypes {
+ extends Expression with ImplicitCastInputTypes with NonSQLExpression {

  override def nullable: Boolean = true

- override def toString: String = s"UDF(${children.mkString(",")})"
+ override def toString: String = s"UDF(${children.mkString(", ")})"

  // scalastyle:off line.size.limit
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.sequenceOption
 import org.apache.spark.sql.types._

 /** The mode of an [[AggregateFunction]]. */

@@ -92,8 +91,6 @@ private[sql] case class AggregateExpression(
    AttributeSet(childReferences)
  }

- override def prettyString: String = aggregateFunction.prettyString
-
  override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)"

+ override def sql: String = aggregateFunction.sql(isDistinct)
@@ -168,7 +165,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu
  }

  def sql(isDistinct: Boolean): String = {
-   val distinct = if (isDistinct) "DISTINCT " else " "
+   val distinct = if (isDistinct) "DISTINCT " else ""
    s"$prettyName($distinct${children.map(_.sql).mkString(", ")})"
  }
 }
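A hypothetical illustration of the spacing fix above (the old code emitted a stray space in the non-distinct case):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate.Max
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
Max(a).sql(isDistinct = false)  // max(`a`)           -- was "max( `a`)"
Max(a).sql(isDistinct = true)   // max(DISTINCT `a`)
```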
@@ -319,7 +319,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
  }
 }

-case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
+case class MaxOf(left: Expression, right: Expression)
+  extends BinaryArithmetic with NonSQLExpression {
+
  // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.

  override def inputType: AbstractDataType = TypeCollection.Ordered

@@ -373,7 +375,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
  override def symbol: String = "max"
 }

-case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
+case class MinOf(left: Expression, right: Expression)
+  extends BinaryArithmetic with NonSQLExpression {
+
  // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.

  override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -123,4 +123,6 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
  }

  protected override def nullSafeEval(input: Any): Any = not(input)
+
+ override def sql: String = s"~${child.sql}"
 }
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen

 import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
+import org.apache.spark.sql.catalyst.util.toCommentSafeString

 /**
  * A trait that can be used to provide a fallback mode for expression code generation.

@@ -37,7 +38,7 @@ trait CodegenFallback extends Expression {
    val objectTerm = ctx.freshName("obj")
    if (nullable) {
      s"""
-       /* expression: ${this.toCommentSafeString} */
+       /* expression: ${toCommentSafeString(this.toString)} */
        Object $objectTerm = ((Expression) references[$idx]).eval($input);
        boolean ${ev.isNull} = $objectTerm == null;
        ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};

@@ -48,7 +49,7 @@ trait CodegenFallback extends Expression {
    } else {
      ev.isNull = "false"
      s"""
-       /* expression: ${this.toCommentSafeString} */
+       /* expression: ${toCommentSafeString(this.toString)} */
        Object $objectTerm = ((Expression) references[$idx]).eval($input);
        ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
      """
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -93,6 +93,8 @@ object ExtractValue {
  }
 }

+trait ExtractValue extends Expression
+
 /**
  * Returns the value of fields in the Struct `child`.
  *

@@ -102,13 +104,15 @@ object ExtractValue {
  * For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
  */
 case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
- extends UnaryExpression {
+ extends UnaryExpression with ExtractValue {

  private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType]

  override def dataType: DataType = childSchema(ordinal).dataType
  override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
  override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
+ override def sql: String =
+   child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"

  protected override def nullSafeEval(input: Any): Any =
    input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)

@@ -130,12 +134,11 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
      }
    })
  }
-
- override def sql: String = child.sql + s".`${childSchema(ordinal).name}`"
 }

 /**
- * Returns the array of value of fields in the Array of Struct `child`.
+ * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array
+ * elements, and returns them as a new array.
  *
  * No need to do type checking since it is handled by [[ExtractValue]].
  */

@@ -144,10 +147,11 @@ case class GetArrayStructFields(
    field: StructField,
    ordinal: Int,
    numFields: Int,
-   containsNull: Boolean) extends UnaryExpression {
+   containsNull: Boolean) extends UnaryExpression with ExtractValue {

  override def dataType: DataType = ArrayType(field.dataType, containsNull)
  override def toString: String = s"$child.${field.name}"
+ override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"

  protected override def nullSafeEval(input: Any): Any = {
    val array = input.asInstanceOf[ArrayData]

@@ -204,12 +208,13 @@ case class GetArrayStructFields(
  * We need to do type checking here as `ordinal` expression maybe unresolved.
  */
 case class GetArrayItem(child: Expression, ordinal: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with ExtractValue {

  // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)

  override def toString: String = s"$child[$ordinal]"
+ override def sql: String = s"${child.sql}[${ordinal.sql}]"

  override def left: Expression = child
  override def right: Expression = ordinal

@@ -250,7 +255,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
  * We need to do type checking here as `key` expression maybe unresolved.
  */
 case class GetMapValue(child: Expression, key: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with ExtractValue {

  private def keyType = child.dataType.asInstanceOf[MapType].keyType

@@ -258,6 +263,7 @@ case class GetMapValue(child: Expression, key: Expression)
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)

  override def toString: String = s"$child[$key]"
+ override def sql: String = s"${child.sql}[${key.sql}]"

  override def left: Expression = child
  override def right: Expression = key
@@ -35,7 +35,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
      TypeCheckResult.TypeCheckFailure(
        s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
    } else if (trueValue.dataType != falseValue.dataType) {
-     TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+     TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
        s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
    } else {
      TypeCheckResult.TypeCheckSuccess
@@ -238,37 +238,20 @@ case class Literal protected (value: Any, dataType: DataType)
  }

  override def sql: String = (value, dataType) match {
-   case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null =>
-     "NULL"
-
-   case _ if value == null =>
-     s"CAST(NULL AS ${dataType.sql})"
-
+   case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL"
+   case _ if value == null => s"CAST(NULL AS ${dataType.sql})"
    case (v: UTF8String, StringType) =>
      // Escapes all backslashes and double quotes.
      "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
-
-   case (v: Byte, ByteType) =>
-     s"CAST($v AS ${ByteType.simpleString.toUpperCase})"
-
-   case (v: Short, ShortType) =>
-     s"CAST($v AS ${ShortType.simpleString.toUpperCase})"
-
-   case (v: Long, LongType) =>
-     s"CAST($v AS ${LongType.simpleString.toUpperCase})"
-
-   case (v: Float, FloatType) =>
-     s"CAST($v AS ${FloatType.simpleString.toUpperCase})"
-
-   case (v: Decimal, DecimalType.Fixed(precision, scale)) =>
-     s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))"
-
-   case (v: Int, DateType) =>
-     s"DATE '${DateTimeUtils.toJavaDate(v)}'"
-
-   case (v: Long, TimestampType) =>
-     s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
-
+   case (v: Byte, ByteType) => v + "Y"
+   case (v: Short, ShortType) => v + "S"
+   case (v: Long, LongType) => v + "L"
+   // Float type doesn't have a suffix
+   case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})"
+   case (v: Double, DoubleType) => v + "D"
+   case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})"
+   case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'"
+   case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
    case _ => value.toString
  }
 }
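The new literal rendering is considerably more compact; a few hypothetical REPL checks (expected values read off the match arms above):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

Literal(1.toByte).sql                 // 1Y    -- was CAST(1 AS TINYINT)
Literal(2.toShort).sql                // 2S    -- was CAST(2 AS SMALLINT)
Literal(3L).sql                       // 3L    -- was CAST(3 AS BIGINT)
Literal(4.5).sql                      // 4.5D  -- Double now gets a suffix
Literal.create(null, StringType).sql  // CAST(NULL AS STRING)
```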
@@ -42,6 +42,7 @@ abstract class LeafMathExpression(c: Double, name: String)
  override def foldable: Boolean = true
  override def nullable: Boolean = false
  override def toString: String = s"$name()"
+ override def prettyName: String = name

  override def eval(input: InternalRow): Any = c
 }

@@ -59,6 +60,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
  override def dataType: DataType = DoubleType
  override def nullable: Boolean = true
  override def toString: String = s"$name($child)"
+ override def prettyName: String = name

  protected override def nullSafeEval(input: Any): Any = {
    f(input.asInstanceOf[Double])

@@ -70,8 +72,6 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
  }
-
- override def sql: String = s"$name(${child.sql})"
 }

 abstract class UnaryLogExpression(f: Double => Double, name: String)

@@ -113,6 +113,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)

  override def toString: String = s"$name($left, $right)"

+ override def prettyName: String = name
+
  override def dataType: DataType = DoubleType

  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -22,6 +22,7 @@ import java.util.UUID
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.types._

 object NamedExpression {

@@ -183,9 +184,9 @@ case class Alias(child: Expression, name: String)(

  override def sql: String = {
    val qualifiersString =
-     if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+     if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
    val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
-   s"${child.sql} AS $qualifiersString`$aliasName`"
+   s"${child.sql} AS $qualifiersString${quoteIdentifier(aliasName)}"
  }
 }

@@ -300,9 +301,9 @@ case class AttributeReference(

  override def sql: String = {
    val qualifiersString =
-     if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+     if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
    val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
-   s"$qualifiersString`$attrRefName`"
+   s"$qualifiersString${quoteIdentifier(attrRefName)}"
  }
 }

@@ -310,10 +311,19 @@ case class AttributeReference(
  * A place holder used when printing expressions without debugging information such as the
  * expression id or the unresolved indicator.
  */
-case class PrettyAttribute(name: String, dataType: DataType = NullType)
+case class PrettyAttribute(
+   name: String,
+   dataType: DataType = NullType)
  extends Attribute with Unevaluable {

+ def this(attribute: Attribute) = this(attribute.name, attribute match {
+   case a: AttributeReference => a.dataType
+   case a: PrettyAttribute => a.dataType
+   case _ => NullType
+ })
+
  override def toString: String = name
+ override def sql: String = toString

  override def withNullability(newNullability: Boolean): Attribute =
    throw new UnsupportedOperationException
@@ -83,8 +83,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
    """
    }.mkString("\n")
  }
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
 }
@@ -46,7 +46,7 @@ case class StaticInvoke(
    dataType: DataType,
    functionName: String,
    arguments: Seq[Expression] = Nil,
-   propagateNull: Boolean = true) extends Expression {
+   propagateNull: Boolean = true) extends Expression with NonSQLExpression {

  val objectName = staticObject.getName.stripSuffix("$")

@@ -108,7 +108,7 @@ case class Invoke(
    targetObject: Expression,
    functionName: String,
    dataType: DataType,
-   arguments: Seq[Expression] = Nil) extends Expression {
+   arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {

  override def nullable: Boolean = true
  override def children: Seq[Expression] = arguments.+:(targetObject)

@@ -204,7 +204,7 @@ case class NewInstance(
    arguments: Seq[Expression],
    propagateNull: Boolean,
    dataType: DataType,
-   outerPointer: Option[Literal]) extends Expression {
+   outerPointer: Option[Literal]) extends Expression with NonSQLExpression {
  private val className = cls.getName

  override def nullable: Boolean = propagateNull

@@ -268,7 +268,7 @@ case class NewInstance(
  */
 case class UnwrapOption(
    dataType: DataType,
-   child: Expression) extends UnaryExpression with ExpectsInputTypes {
+   child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {

  override def nullable: Boolean = true

@@ -298,7 +298,7 @@ case class UnwrapOption(
  * @param optType The type of this option.
  */
 case class WrapOption(child: Expression, optType: DataType)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {

  override def dataType: DataType = ObjectType(classOf[Option[_]])

@@ -328,7 +328,7 @@ case class WrapOption(child: Expression, optType: DataType)
  * manually, but will instead be passed into the provided lambda function.
  */
 case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
- with Unevaluable {
+ with Unevaluable with NonSQLExpression {

  override def nullable: Boolean = true

@@ -368,7 +368,7 @@ object MapObjects {
 case class MapObjects private(
    loopVar: LambdaVariable,
    lambdaFunction: Expression,
-   inputData: Expression) extends Expression {
+   inputData: Expression) extends Expression with NonSQLExpression {

  private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
    case NullType =>

@@ -483,7 +483,7 @@ case class MapObjects private(
  *
  * @param children A list of expression to use as content of the external row.
  */
-case class CreateExternalRow(children: Seq[Expression]) extends Expression {
+case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression {
  override def dataType: DataType = ObjectType(classOf[Row])

  override def nullable: Boolean = false

@@ -516,7 +516,8 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
  * Serializes an input object using a generic serializer (Kryo or Java).
  * @param kryo if true, use Kryo. Otherwise, use Java.
  */
-case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression {
+case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
+ extends UnaryExpression with NonSQLExpression {

  override def eval(input: InternalRow): Any =
    throw new UnsupportedOperationException("Only code-generated evaluation is supported")

@@ -558,7 +559,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary
  * @param kryo if true, use Kryo. Otherwise, use Java.
  */
 case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
- extends UnaryExpression {
+ extends UnaryExpression with NonSQLExpression {

  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    // Code to initialize the serializer.

@@ -596,7 +597,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
  * Initialize a Java Bean instance by setting its field values via setters.
  */
 case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
- extends Expression {
+ extends Expression with NonSQLExpression {

  override def nullable: Boolean = beanInstance.nullable
  override def children: Seq[Expression] = beanInstance +: setters.values.toSeq

@@ -638,7 +639,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
  * non-null `s`, `s.i` can't be null.
  */
 case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
- extends UnaryExpression {
+ extends UnaryExpression with NonSQLExpression {

  override def dataType: DataType = child.dataType
@@ -249,6 +249,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with

  override def symbol: String = "&&"

+ override def sqlOperator: String = "AND"
+
  override def eval(input: InternalRow): Any = {
    val input1 = left.eval(input)
    if (input1 == false) {

@@ -289,8 +291,6 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
    }
    """
  }
-
- override def sql: String = s"(${left.sql} AND ${right.sql})"
 }

@@ -300,6 +300,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P

  override def symbol: String = "||"

+ override def sqlOperator: String = "OR"
+
  override def eval(input: InternalRow): Any = {
    val input1 = left.eval(input)
    if (input1 == true) {

@@ -340,8 +342,6 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
    }
    """
  }
-
- override def sql: String = s"(${left.sql} OR ${right.sql})"
 }
@@ -23,7 +23,6 @@ import java.util.{HashMap, Locale, Map => JMap}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.catalyst.util.sequenceOption
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

@@ -62,8 +61,6 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
    }
    """
  }
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
 }

@@ -156,8 +153,6 @@ case class ConcatWs(children: Seq[Expression])
    """
    }
  }
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
 }

 trait String2StringExpression extends ImplicitCastInputTypes {
@@ -556,8 +556,8 @@ abstract class RankLike extends AggregateWindowFunction {
  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)

  /** Store the values of the window 'order' expressions. */
- protected val orderAttrs = children.map{ expr =>
-   AttributeReference(expr.prettyString, expr.dataType)()
+ protected val orderAttrs = children.map { expr =>
+   AttributeReference(expr.sql, expr.dataType)()
  }

  /** Predicate that detects if the order attributes have changed. */

@@ -636,7 +636,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {

 /**
  * The PercentRank function computes the percentage ranking of a value in a group of values. The
- * result the rank of the minus one divided by the total number of rows in the partitiion minus one:
+ * result the rank of the minus one divided by the total number of rows in the partition minus one:
  * (r - 1) / (n - 1). If a partition only contains one row, the function will return 0.
  *
  * The PercentRank function is similar to the CumeDist function, but it uses rank values instead of
@@ -575,7 +575,7 @@ case class Pivot(
  override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
    case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
    case _ => pivotValues.flatMap{ value =>
-     aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
+     aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
    }
  }
 }
@@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst

 import java.io._

+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{NumericType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils

 package object util {

@@ -130,20 +133,32 @@ package object util {
    ret
  }

- /**
-  * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`.
-  */
- def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
-   case xs if xs.isEmpty =>
-     Option(Seq.empty[T])
-
-   case xs =>
-     for {
-       head <- xs.head
-       tail <- sequenceOption(xs.tail)
-     } yield head +: tail
- }
+ // Replaces attributes, string literals, complex type extractors with their pretty form so that
+ // generated column names don't contain back-ticks or double-quotes.
+ def usePrettyExpression(e: Expression): Expression = e transform {
+   case a: Attribute => new PrettyAttribute(a)
+   case Literal(s: UTF8String, StringType) => PrettyAttribute(s.toString, StringType)
+   case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t)
+   case e: GetStructField =>
+     val name = e.name.getOrElse(e.childSchema(e.ordinal).name)
+     PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType)
+   case e: GetArrayStructFields =>
+     PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType)
+ }
+
+ def quoteIdentifier(name: String): String = {
+   // Escapes back-ticks within the identifier name with double-back-ticks, and then quote the
+   // identifier with back-ticks.
+   "`" + name.replace("`", "``") + "`"
+ }
+
+ /**
+  * Returns the string representation of this expression that is safe to be put in
+  * code comments of generated code.
+  */
+ def toCommentSafeString(str: String): String =
+   str.replace("*/", "\\*\\/").replace("\\u", "\\\\u")

  /* FIX ME
  implicit class debugLogging(a: Any) {
    def debugLogging() {
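A quick hypothetical check of the two helpers introduced above: `quoteIdentifier` protects identifiers in generated SQL, while `usePrettyExpression` is used for generated column names, where back-ticks are unwanted:

```scala
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, usePrettyExpression}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

quoteIdentifier("name")  // `name`
quoteIdentifier("a`b")   // `a``b`  (embedded back-tick doubled, then quoted)

val a = AttributeReference("a", IntegerType)()
a.sql                       // `a`  (quoted, for SQL generation)
usePrettyExpression(a).sql  // a    (PrettyAttribute.sql is just the name)
```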
@@ -64,6 +64,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {

  override def toString: String = s"DecimalType($precision,$scale)"

+ override def sql: String = typeName.toUpperCase
+
  /**
   * Returns whether this DecimalType is wider than `other`. If yes, it means `other`
   * can be casted into `this` safely without losing any precision or range.
@@ -25,7 +25,7 @@ import org.json4s.JsonDSL._
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
-import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, DataTypeParser, LegacyTypeStringParser}

 /**
  * :: DeveloperApi ::

@@ -280,7 +280,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
  }

  override def sql: String = {
-   val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}")
+   val fieldTypes = fields.map(f => s"${quoteIdentifier(f.name)}: ${f.dataType.sql}")
    s"STRUCT<${fieldTypes.mkString(", ")}>"
  }
@ -127,22 +127,24 @@ class AnalysisErrorSuite extends AnalysisTest {
|
|||
errorTest(
|
||||
"single invalid type, single arg",
|
||||
testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
|
||||
"cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" ::
|
||||
"'null' is of date type" :: Nil)
|
||||
"cannot resolve" :: "testfunction(CAST(NULL AS DATE))" :: "argument 1" :: "requires int type" ::
|
||||
"'CAST(NULL AS DATE)' is of date type" :: Nil)
|
||||
|
||||
errorTest(
|
||||
"single invalid type, second arg",
|
||||
testRelation.select(
|
||||
TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
|
||||
"cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" ::
|
||||
"'null' is of date type" :: Nil)
|
||||
"cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" ::
|
||||
"argument 2" :: "requires int type" ::
|
||||
"'CAST(NULL AS DATE)' is of date type" :: Nil)
|
||||
|
||||
errorTest(
|
||||
"multiple invalid type",
|
||||
testRelation.select(
|
||||
TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
|
||||
"cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
|
||||
"requires int type" :: "'null' is of date type" :: Nil)
|
||||
"cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" ::
|
||||
"argument 1" :: "argument 2" :: "requires int type" ::
|
||||
"'CAST(NULL AS DATE)' is of date type" :: Nil)
|
||||
|
||||
errorTest(
|
||||
"invalid window function",
|
||||
|
@ -207,7 +209,7 @@ class AnalysisErrorSuite extends AnalysisTest {
|
|||
errorTest(
|
||||
"sorting by attributes are not from grouping expressions",
|
||||
testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc),
|
||||
"cannot resolve" :: "'b'" :: "given input columns" :: "[a, c, a3]" :: Nil)
|
||||
"cannot resolve" :: "'`b`'" :: "given input columns" :: "[a, c, a3]" :: Nil)
|
||||
|
||||
errorTest(
|
||||
"non-boolean filters",
|
||||
|
@ -222,7 +224,7 @@ class AnalysisErrorSuite extends AnalysisTest {
|
|||
errorTest(
|
||||
"missing group by",
|
||||
testRelation2.groupBy('a)('b),
|
||||
"'b'" :: "group by" :: Nil
|
||||
"'`b`'" :: "group by" :: Nil
|
||||
)
|
||||
|
||||
errorTest(
|
||||
|
@ -270,7 +272,7 @@ class AnalysisErrorSuite extends AnalysisTest {
|
|||
"SPARK-9955: correct error message for aggregate",
|
||||
// When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias.
|
||||
testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),
|
||||
"cannot resolve 'bad_column'" :: Nil)
|
||||
"cannot resolve '`bad_column`'" :: Nil)
|
||||
|
||||
test("SPARK-6452 regression test") {
|
||||
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
|
||||
|
@ -311,7 +313,7 @@ class AnalysisErrorSuite extends AnalysisTest {
|
|||
case true =>
|
||||
assertAnalysisSuccess(plan, true)
|
||||
case false =>
|
||||
assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
|
||||
assertAnalysisError(plan, "expression `a` cannot be used as a grouping expression" :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -372,7 +374,7 @@ class AnalysisErrorSuite extends AnalysisTest {
|
|||
Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
|
||||
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
|
||||
|
||||
assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil)
|
||||
assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)
|
||||
|
||||
val plan2 =
|
||||
Join(
|
||||
|
@@ -386,6 +388,6 @@ class AnalysisErrorSuite extends AnalysisTest {
       Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
        AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))

-    assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil)
+    assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
   }
 }
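All of the expectation changes above follow one theme: analysis error messages now wrap attribute names in backticks. A minimal sketch of that quoting convention (the helper name mirrors the `quoteIdentifier` utility imported later in this diff; its body here is an assumption, with the backtick-doubling inferred from the "we`ird" column-name test further down):

```scala
// Sketch of backtick quoting for identifiers embedded in error messages.
// Assumed behavior: wrap in backticks and double any embedded backtick.
object BacktickQuoting {
  def quoteIdentifier(name: String): String = "`" + name.replace("`", "``") + "`"

  def main(args: Array[String]): Unit = {
    // Matches the updated expectation: cannot resolve '`b`' given input columns: [a, c, a3]
    println(s"cannot resolve '${quoteIdentifier("b")}' given input columns: [a, c, a3]")
    println(quoteIdentifier("we`ird"))  // `we``ird`
  }
}
```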
@@ -71,8 +71,17 @@ trait AnalysisTest extends PlanTest {
     val e = intercept[AnalysisException] {
       analyzer.checkAnalysis(analyzer.execute(inputPlan))
     }
-    assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains),
-      s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " +
-        s"actually we get ${e.getMessage}")
+
+    if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) {
+      fail(
+        s"""Exception message should contain the following substrings:
+           |
+           |  ${expectedErrors.mkString("\n  ")}
+           |
+           |Actual exception message:
+           |
+           |  ${e.getMessage}
+         """.stripMargin)
+    }
   }
 }
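The rewritten check trades the old single-line assert message for a multi-line report that is much easier to scan when several substrings are expected. A self-contained sketch of the same pattern (`checkErrorMessage` is an illustrative name, not part of `AnalysisTest`):

```scala
// Check all expected substrings case-insensitively; on failure, report both
// the expectations and the actual message in a readable block.
def checkErrorMessage(actualMessage: String, expectedErrors: Seq[String]): Unit = {
  val allFound = expectedErrors.map(_.toLowerCase).forall(actualMessage.toLowerCase.contains)
  if (!allFound) {
    throw new AssertionError(
      s"""Exception message should contain the following substrings:
         |
         |  ${expectedErrors.mkString("\n  ")}
         |
         |Actual exception message:
         |
         |  $actualMessage
       """.stripMargin)
  }
}
```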
@@ -41,7 +41,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       assertSuccess(expr)
     }
     assert(e.getMessage.contains(
-      s"cannot resolve '${expr.prettyString}' due to data type mismatch:"))
+      s"cannot resolve '${expr.sql}' due to data type mismatch:"))
     assert(e.getMessage.contains(errorMessage))
   }

@@ -52,7 +52,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {

   def assertErrorForDifferingTypes(expr: Expression): Unit = {
     assertError(expr,
-      s"differing types in '${expr.prettyString}'")
+      s"differing types in '${expr.sql}'")
   }

   test("check types for unary arithmetic") {
@@ -142,7 +142,7 @@ class EncoderResolutionSuite extends PlanTest {
     }.message
     assert(msg2 ==
       s"""
-         |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate
+         |Cannot up cast `b`.`b` from decimal(38,18) to bigint as it may truncate
          |The type path of the target object is:
          |- field (class: "scala.Long", name: "b")
          |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b")
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DataTypeParser
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DataTypeParser}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._

@@ -104,10 +104,9 @@ class Column(protected[sql] val expr: Expression) extends Logging {

   def this(name: String) = this(name match {
     case "*" => UnresolvedStar(None)
-    case _ if name.endsWith(".*") => {
+    case _ if name.endsWith(".*") =>
       val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2))
       UnresolvedStar(Some(parts))
-    }
     case _ => UnresolvedAttribute.quotedString(name)
   })

@@ -123,6 +122,8 @@ class Column(protected[sql] val expr: Expression) extends Logging {
     // make it a NamedExpression.
     case u: UnresolvedAttribute => UnresolvedAlias(u)

+    case u: UnresolvedExtractValue => UnresolvedAlias(u)
+
     case expr: NamedExpression => expr

     // Leave an unaliased generator with an empty list of names since the analyzer will generate
@@ -131,7 +132,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {

     case jt: JsonTuple => MultiAlias(jt, Nil)

-    case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString))
+    case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql))

     // If we have a top level Cast, there is a chance to give it a better alias, if there is a
     // NamedExpression under this Cast.
@@ -139,13 +140,13 @@ class Column(protected[sql] val expr: Expression) extends Logging {
       case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
     } match {
       case ne: NamedExpression => ne
-      case other => Alias(expr, expr.prettyString)()
+      case other => Alias(expr, usePrettyExpression(expr).sql)()
     }

-    case expr: Expression => Alias(expr, expr.prettyString)()
+    case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
   }

-  override def toString: String = expr.prettyString
+  override def toString: String = usePrettyExpression(expr).sql

   override def equals(that: Any): Boolean = that match {
     case that: Column => that.expr.equals(this.expr)

@@ -987,7 +988,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
     if (extended) {
       println(expr)
     } else {
-      println(expr.prettyString)
+      println(expr.sql)
     }
     // scalastyle:on println
   }
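The user-visible effect of switching `Column`'s generated aliases from `prettyString` to `usePrettyExpression(expr).sql` is that unnamed select expressions get SQL-flavored column names. A rough spark-shell sketch (assumes a Spark 1.6-era shell where `sqlContext` and its implicits are available; the exact generated name in the comment is an illustration):

```scala
import sqlContext.implicits._

val df = Seq((true, false)).toDF("a", "b")

// Previously the auto-generated name came from prettyString ("a && b");
// with this change it is SQL-like, along the lines of "(a AND b)".
df.select($"a" && $"b").columns.foreach(println)
```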
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator

@@ -1359,7 +1360,8 @@ class DataFrame private[sql](
       "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
       "max" -> ((child: Expression) => Max(child).toAggregateExpression()))

-    val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+    val outputCols =
+      (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList

     val ret: Seq[Row] = if (outputCols.nonEmpty) {
       val aggExprs = statistics.flatMap { case (_, colToAgg) =>
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
+import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.types.NumericType

 /**
@@ -74,7 +75,7 @@ class GroupedData protected[sql](
   private[this] def alias(expr: Expression): NamedExpression = expr match {
     case u: UnresolvedAttribute => UnresolvedAlias(u)
     case expr: NamedExpression => expr
-    case expr: Expression => Alias(expr, expr.prettyString)()
+    case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
   }

   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue

@@ -252,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }

     /** Codegened pipeline for:
-      * ${plan.treeString.trim}
+      * ${toCommentSafeString(plan.treeString.trim)}
      */
     class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
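Running the plan string through `toCommentSafeString` matters because a plan's `treeString` can itself contain `*/`, which would terminate the generated class's header comment early and break compilation. A plausible sketch of such a helper (the actual implementation in `catalyst.util` may escape more than this):

```scala
// Hypothetical comment-escaping helper: defuse the "*/" terminator so an
// arbitrary string can be embedded inside a /* ... */ comment.
def toCommentSafeString(str: String): String =
  str.replace("*/", "*\\/")

println(toCommentSafeString("Filter (a */ b)"))  // Filter (a *\/ b)
```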
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _}
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

@@ -324,7 +324,7 @@ private[sql] case class ScalaUDAF(
     udaf: UserDefinedAggregateFunction,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with Logging {
+  extends ImperativeAggregate with NonSQLExpression with Logging {

   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
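`ScalaUDAF` here (and `PythonUDF` below) are the first users of the new `NonSQLExpression` marker trait described in the summary. A simplified model of the pattern, not Catalyst's actual class hierarchy:

```scala
// Expressions without a SQL form mix in a marker trait; a SQL generator
// refuses any tree that contains one.
trait Expr { def sql: String }
trait NonSQL extends Expr  // marker: `sql` may be an arbitrary user-facing string

case class Attr(name: String) extends Expr {
  def sql: String = s"`$name`"
}
case class OpaqueUdf(name: String, child: Expr) extends Expr with NonSQL {
  def sql: String = s"$name(${child.sql})"  // user-facing, not guaranteed valid SQL
}

def toSQL(e: Expr): String = e match {
  case ns: NonSQL =>
    throw new UnsupportedOperationException(s"Expression $ns doesn't have a SQL representation")
  case other => other.sql
}
```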
@@ -544,7 +544,7 @@ private[parquet] object CatalystSchemaConverter {
       !name.matches(".*[ ,;{}()\n\t=].*"),
       s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
          |Please use alias to rename it.
-       """.stripMargin.split("\n").mkString(" "))
+       """.stripMargin.split("\n").mkString(" ").trim)
   }

   def checkFieldNames(schema: StructType): StructType = {
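For context, the check above in standalone form. The newly appended `.trim` removes the trailing whitespace left over when the `stripMargin` lines (the last of which is blank) are joined with spaces:

```scala
// Illustrative standalone version of the Parquet field-name check.
def checkFieldName(name: String): Unit = {
  require(
    !name.matches(".*[ ,;{}()\n\t=].*"),
    s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
       |Please use alias to rename it.
     """.stripMargin.split("\n").mkString(" ").trim)
}

checkFieldName("ok_name")      // passes
// checkFieldName("bad name")  // would throw IllegalArgumentException with a single-line message
```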
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.spark.{Accumulator, Logging}
 import org.apache.spark.api.python.PythonBroadcast
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType

 /**
@@ -36,9 +36,12 @@ case class PythonUDF(
     broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[java.util.List[Array[Byte]]],
     dataType: DataType,
-    children: Seq[Expression]) extends Expression with Unevaluable with Logging {
+    children: Seq[Expression])
+  extends Expression with Unevaluable with NonSQLExpression with Logging {

-  override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
+  override def toString: String = s"PythonUDF#$name(${children.mkString(", ")})"

   override def nullable: Boolean = true
+
+  override def sql: String = s"$name(${children.mkString(", ")})"
 }
@@ -525,7 +525,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val e = intercept[AnalysisException] {
       ds.as[ClassData2]
     }
-    assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
+    assert(e.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"), e.getMessage)
   }

   test("runtime nullability check") {
@@ -455,7 +455,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
     sqlContext.udf.register("div0", (x: Int) => x / 0)
     withTempPath { dir =>
       intercept[org.apache.spark.SparkException] {
-        sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+        sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
       }
       val path = new Path(dir.getCanonicalPath, "_temporary")
       val fs = path.getFileSystem(hadoopConfiguration)

@@ -479,7 +479,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
     sqlContext.udf.register("div0", (x: Int) => x / 0)
     withTempPath { dir =>
       intercept[org.apache.spark.SparkException] {
-        sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+        sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
       }
       val path = new Path(dir.getCanonicalPath, "_temporary")
       val fs = path.getFileSystem(hadoopConfiguration)
@@ -24,10 +24,11 @@ import scala.util.control.NonFatal
 import org.apache.spark.Logging
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression, SortOrder}
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.execution.datasources.LogicalRelation

 /**
@@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
  * supported by this builder (yet).
  */
 class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
   require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans")

   def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)

   def toSQL: String = {
     val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
     try {
+      canonicalizedPlan.transformAllExpressions {
+        case e: NonSQLExpression =>
+          throw new UnsupportedOperationException(
+            s"Expression $e doesn't have a SQL representation"
+          )
+        case e => e
+      }
+
       val generatedSQL = toSQL(canonicalizedPlan)
       logDebug(
         s"""Built SQL query string successfully from given logical plan:

@@ -95,7 +106,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
     p.child match {
       // Persisted data source relation
       case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
-        s"`$database`.`$table`"
+        s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
       // Parentheses are not used for persisted data source relations
       // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
       case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>

@@ -114,8 +125,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {

     case p: MetastoreRelation =>
       build(
-        s"`${p.databaseName}`.`${p.tableName}`",
-        p.alias.map(a => s" AS `$a`").getOrElse("")
+        s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}",
+        p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("")
       )

     case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))

@@ -148,7 +159,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
    * The segments are trimmed so only a single space appears in the separation.
    * For example, `build("a", " b ", " c")` becomes "a b c".
    */
-  private def build(segments: String*): String = segments.map(_.trim).mkString(" ")
+  private def build(segments: String*): String =
+    segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

   private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
     build(
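The small `build` change is easy to miss but fixes a real formatting wart: without `filter(_.nonEmpty)`, an empty segment (for example an absent table alias, as in the `MetastoreRelation` case above) left a stray double space in the generated SQL. The updated helper, with a quick demonstration:

```scala
def build(segments: String*): String =
  segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

println(build("SELECT *", "FROM `db`.`tbl`", ""))  // SELECT * FROM `db`.`tbl`
println(build("a", " b ", " c"))                   // a b c (per the doc comment above)
```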
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.sequenceOption
 import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.types._

@@ -70,18 +69,28 @@ private[hive] class HiveFunctionRegistry(
     // catch the exception and throw AnalysisException instead.
     try {
       if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveGenericUDF(
+        val udf = HiveGenericUDF(
           name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
+        udf.dataType // Force it to check input data types.
+        udf
       } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        udf.dataType // Force it to check input data types.
+        udf
       } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        udf.dataType // Force it to check input data types.
+        udf
       } else if (
        classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+        val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+        udaf.dataType // Force it to check input data types.
+        udaf
       } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveUDAFFunction(
+        val udaf = HiveUDAFFunction(
           name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
+        udaf.dataType // Force it to check input data types.
+        udaf
       } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
         val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
         udtf.elementTypes // Force it to check input data types.

@@ -163,7 +172,7 @@ private[hive] case class HiveSimpleUDF(
   @transient
   private lazy val conversionHelper = new ConversionHelper(method, arguments)

-  override val dataType = javaClassToDataType(method.getReturnType)
+  override lazy val dataType = javaClassToDataType(method.getReturnType)

   @transient
   lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(

@@ -189,6 +198,8 @@ private[hive] case class HiveSimpleUDF(
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }

   override def prettyName: String = name
+
+  override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
 }

@@ -233,11 +244,11 @@ private[hive] case class HiveGenericUDF(
   }

   @transient
-  private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
+  private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
     new DeferredObjectAdapter(inspect, child.dataType)
   }.toArray[DeferredObject]

-  override val dataType: DataType = inspectorToDataType(returnInspector)
+  override lazy val dataType: DataType = inspectorToDataType(returnInspector)

   override def eval(input: InternalRow): Any = {
     returnInspector // Make sure initialized.

@@ -245,20 +256,20 @@ private[hive] case class HiveGenericUDF(
     var i = 0
     while (i < children.length) {
       val idx = i
-      deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
+      deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(
         () => {
           children(idx).eval(input)
         })
       i += 1
     }
-    unwrap(function.evaluate(deferedObjects), returnInspector)
+    unwrap(function.evaluate(deferredObjects), returnInspector)
   }

   override def prettyName: String = name

   override def toString: String = {
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
+
+  override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
 }

 /**
@@ -340,7 +351,7 @@ private[hive] case class HiveGenericUDTF(
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }

+  override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
   override def prettyName: String = name
 }

 /**
@@ -432,7 +443,9 @@ private[hive] case class HiveUDAFFunction(

   override def supportsPartial: Boolean = false

-  override val dataType: DataType = inspectorToDataType(returnInspector)
+  override lazy val dataType: DataType = inspectorToDataType(returnInspector)

+  override def prettyName: String = name
+
   override def sql(isDistinct: Boolean): String = {
     val distinct = if (isDistinct) "DISTINCT " else " "
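The registry changes all follow one pattern: `dataType` becomes a `lazy val` on the UDF classes, and the registry touches it immediately after construction so that bad argument types surface inside the surrounding `try` (where they can be rethrown as `AnalysisException`) instead of later at execution time. A minimal model of the pattern (names here are illustrative):

```scala
// `dataType` is derived lazily from the inputs; forcing it right after
// construction moves input validation to lookup time.
class FakeUdf(inputTypes: Seq[String]) {
  lazy val dataType: String = {
    require(inputTypes.forall(_ == "int"), "requires int type")
    "int"
  }
}

def lookupFunction(inputTypes: Seq[String]): FakeUdf = {
  val udf = new FakeUdf(inputTypes)
  udf.dataType // Force it to check input data types.
  udf
}
```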
@@ -26,17 +26,24 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
   test("literal") {
     checkSQL(Literal("foo"), "\"foo\"")
     checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
-    checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
-    checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
+    checkSQL(Literal(1: Byte), "1Y")
+    checkSQL(Literal(2: Short), "2S")
     checkSQL(Literal(4: Int), "4")
-    checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
+    checkSQL(Literal(8: Long), "8L")
     checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
-    checkSQL(Literal(2.5D), "2.5")
+    checkSQL(Literal(2.5D), "2.5D")
     checkSQL(
       Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
     // TODO tests for decimals
   }

+  test("attributes") {
+    checkSQL('a.int, "`a`")
+    checkSQL(Symbol("foo bar").int, "`foo bar`")
+    // Keyword
+    checkSQL('int.int, "`int`")
+  }
+
   test("binary comparisons") {
     checkSQL('a.int === 'b.int, "(`a` = `b`)")
     checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
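The literal tests pin down the new, more compact SQL rendering of numeric literals: type suffixes (`1Y`, `2S`, `8L`, `2.5D`) replace the explicit `CAST`s, while `FLOAT` keeps a `CAST` since this dialect has no float suffix. A sketch of the mapping these tests encode (illustrative, not Catalyst's actual `Literal.sql`):

```scala
def literalSql(v: Any): String = v match {
  case b: Byte   => s"${b}Y"              // TINYINT suffix
  case s: Short  => s"${s}S"              // SMALLINT suffix
  case i: Int    => i.toString            // plain
  case l: Long   => s"${l}L"              // BIGINT suffix
  case f: Float  => s"CAST($f AS FLOAT)"  // no float suffix; keep the CAST
  case d: Double => s"${d}D"              // DOUBLE suffix
  case s: String => "\"" + s.replace("\"", "\\\"") + "\""
}

println(literalSql(8L))   // 8L
println(literalSql(2.5))  // 2.5D
```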
@@ -29,6 +29,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     sql("DROP TABLE IF EXISTS t0")
     sql("DROP TABLE IF EXISTS t1")
     sql("DROP TABLE IF EXISTS t2")
+
     sqlContext.range(10).write.saveAsTable("t0")

     sqlContext

@@ -168,4 +169,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
       }
     }
   }
+
+  test("plans with non-SQL expressions") {
+    sqlContext.udf.register("foo", (_: Int) * 2)
+    intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL)
+  }
+
+  test("named expression in column names shouldn't be quoted") {
+    def checkColumnNames(query: String, expectedColNames: String*): Unit = {
+      checkHiveQl(query)
+      assert(sql(query).columns === expectedColNames)
+    }
+
+    // Attributes
+    checkColumnNames(
+      """SELECT * FROM (
+        |  SELECT 1 AS a, 2 AS b, 3 AS `we``ird`
+        |) s
+      """.stripMargin,
+      "a", "b", "we`ird"
+    )
+
+    checkColumnNames(
+      """SELECT x.a, y.a, x.b, y.b
+        |FROM (SELECT 1 AS a, 2 AS b) x
+        |INNER JOIN (SELECT 1 AS a, 2 AS b) y
+        |ON x.a = y.a
+      """.stripMargin,
+      "a", "a", "b", "b"
+    )
+
+    // String literal
+    checkColumnNames(
+      "SELECT 'foo', '\"bar\\''",
+      "foo", "\"bar\'"
+    )
+
+    // Numeric literals (should have CAST or suffixes in column names)
+    checkColumnNames(
+      "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D",
+      "1", "2", "3", "4", "5.1", "6.1"
+    )
+
+    // Aliases
+    checkColumnNames(
+      "SELECT 1 AS a",
+      "a"
+    )
+
+    // Complex type extractors
+    checkColumnNames(
+      """SELECT
+        |  a.f1, b[0].f1, b.f1, c["foo"], d[0]
+        |FROM (
+        |  SELECT
+        |    NAMED_STRUCT("f1", 1, "f2", "foo") AS a,
+        |    ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b,
+        |    MAP("foo", 1) AS c,
+        |    ARRAY(1) AS d
+        |) s
+      """.stripMargin,
+      "f1", "b[0].f1", "f1", "c[foo]", "d[0]"
+    )
+  }
 }
@@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton {

   test("udf constant folding") {
     Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t")
-    val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan
-    val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan
+    val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan
+    val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan

     comparePlans(optimized, correctAnswer)
   }
@@ -131,17 +131,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
     val df = sql(
       """
        |SELECT
-       |  CAST(null as TINYINT),
-       |  CAST(null as SMALLINT),
-       |  CAST(null as INT),
-       |  CAST(null as BIGINT),
-       |  CAST(null as FLOAT),
-       |  CAST(null as DOUBLE),
-       |  CAST(null as DECIMAL(7,2)),
-       |  CAST(null as TIMESTAMP),
-       |  CAST(null as DATE),
-       |  CAST(null as STRING),
-       |  CAST(null as VARCHAR(10))
+       |  CAST(null as TINYINT) as c0,
+       |  CAST(null as SMALLINT) as c1,
+       |  CAST(null as INT) as c2,
+       |  CAST(null as BIGINT) as c3,
+       |  CAST(null as FLOAT) as c4,
+       |  CAST(null as DOUBLE) as c5,
+       |  CAST(null as DECIMAL(7,2)) as c6,
+       |  CAST(null as TIMESTAMP) as c7,
+       |  CAST(null as DATE) as c8,
+       |  CAST(null as STRING) as c9,
+       |  CAST(null as VARCHAR(10)) as c10
        |FROM orc_temp_table limit 1
       """.stripMargin)
@@ -626,7 +626,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
       sql(
         s"""CREATE TABLE array_of_struct
           |STORED AS PARQUET LOCATION '$path'
-          |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b'))
+          |AS SELECT
+          |  '1st' AS a,
+          |  '2nd' AS b,
+          |  ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c
         """.stripMargin)

       checkAnswer(