diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7b5713720d..cc118108f6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1047,13 +1047,13 @@ test_that("column functions", { schema = c("a", "b", "c")) result <- collect(select(df, struct("a", "c"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), + expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) result <- collect(select(df, struct(df$a, df$b))) expected <- data.frame(row.names = 1:2) - expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), + expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 320451c52c..3866a49c0b 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -219,17 +219,17 @@ class Column(object): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() >>> df.select(df.r.getField("b")).show() - +----+ - |r[b]| - +----+ - | b| - +----+ + +---+ + |r.b| + +---+ + | b| + +---+ >>> df.select(df.r.a).show() - +----+ - |r[a]| - +----+ - | 1| - +----+ + +---+ + |r.a| + +---+ + | 1| + +---+ """ return self[name] @@ -346,12 +346,12 @@ class Column(object): expression is between the given columns. >>> df.select(df.name, df.age.between(2, 4)).show() - +-----+--------------------------+ - | name|((age >= 2) && (age <= 4))| - +-----+--------------------------+ - |Alice| true| - | Bob| false| - +-----+--------------------------+ + +-----+---------------------------+ + | name|((age >= 2) AND (age <= 4))| + +-----+---------------------------+ + |Alice| true| + | Bob| false| + +-----+---------------------------+ """ return (self >= lowerBound) & (self <= upperBound) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7f6fb410ab..89bf1443a6 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -92,8 +92,8 @@ class SQLContext(object): >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 
'from allTypes where b and i > 0').collect() - [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ - time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ + dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -210,17 +210,17 @@ class SQLContext(object): >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(_c0=u'4')] + [Row(stringLengthString(test)=u'4')] >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(_c0=4)] + [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(_c0=4)] + [Row(stringLengthInt(test)=4)] """ udf = UserDefinedFunction(f, returnType, name) self._ssql_ctx.udf().registerPython(name, udf._judf) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5fc1cc2cae..fdae05d98c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -223,22 +223,22 @@ def coalesce(*cols): +----+----+ >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() - +-------------+ - |coalesce(a,b)| - +-------------+ - | null| - | 1| - | 2| - +-------------+ + +--------------+ + |coalesce(a, b)| + +--------------+ + | null| + | 1| + | 2| + +--------------+ >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() - +----+----+---------------+ - | a| b|coalesce(a,0.0)| - +----+----+---------------+ - |null|null| 0.0| - | 1|null| 1.0| - |null| 2| 0.0| - +----+----+---------------+ + +----+----+----------------+ + | a| b|coalesce(a, 0.0)| + +----+----+----------------+ + |null|null| 0.0| + | 1|null| 1.0| + |null| 2| 0.0| + +----+----+----------------+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column)) @@ -1528,7 +1528,7 @@ def array_contains(col, value): >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(array_contains(df.data, "a")).collect() - [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] + [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ba306f8b32..e153f4dd2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.types._ /** @@ -165,7 +166,8 @@ class Analyzer( case e if !e.resolved => u case g: Generator => 
MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() - case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))() + case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)() + case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))() } } }.asInstanceOf[Seq[NamedExpression]] @@ -328,7 +330,7 @@ class Analyzer( throw new AnalysisException( s"Aggregate expression required for pivot, found '$aggregate'") } - val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString + val name = if (singleAgg) value.toString else value + "_" + aggregate.sql Alias(filteredAggregate, name)() } } @@ -1456,7 +1458,7 @@ object CleanupAliases extends Rule[LogicalPlan] { */ object ResolveUpCast extends Rule[LogicalPlan] { private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + + throw new AnalysisException(s"Cannot up cast ${from.sql} from " + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fe053b9a0b..1e430c1fbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -57,13 +57,13 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]") + a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => e.failAnalysis( - s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + s"cannot resolve '${e.sql}' due to data type mismatch: $message") } case c: Cast if !c.resolved => @@ -106,12 +106,12 @@ trait CheckAnalysis { operator match { case f: Filter if f.condition.dataType != BooleanType => failAnalysis( - s"filter expression '${f.condition.prettyString}' " + + s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( - s"join condition '${condition.prettyString}' " + + s"join condition '${condition.sql}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") case j @ Join(_, _, _, Some(condition)) => @@ -119,10 +119,10 @@ trait CheckAnalysis { case p: Predicate => p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) case e if e.dataType.isInstanceOf[BinaryType] => - failAnalysis(s"binary type expression ${e.prettyString} cannot be used " + + failAnalysis(s"binary type expression ${e.sql} cannot be used " + "in join conditions") case e if e.dataType.isInstanceOf[MapType] => - failAnalysis(s"map type expression ${e.prettyString} cannot be used " + + failAnalysis(s"map type expression ${e.sql} cannot be used " + "in join conditions") case _ => // OK } @@ -144,13 +144,13 @@ trait 
CheckAnalysis { if (!child.deterministic) { failAnalysis( - s"nondeterministic expression ${expr.prettyString} should not " + + s"nondeterministic expression ${expr.sql} should not " + s"appear in the arguments of an aggregate function.") } } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( - s"expression '${e.prettyString}' is neither present in the group by, " + + s"expression '${e.sql}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() (or first_value) if you don't care " + "which value you get.") @@ -163,7 +163,7 @@ trait CheckAnalysis { // Check if the data type of expr is orderable. if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( - s"expression ${expr.prettyString} cannot be used as a grouping expression " + + s"expression ${expr.sql} cannot be used as a grouping expression " + s"because its data type ${expr.dataType.simpleString} is not a orderable " + s"data type.") } @@ -172,7 +172,7 @@ trait CheckAnalysis { // This is just a sanity check, our analysis rule PullOutNondeterministic should // already pull out those nondeterministic expressions and evaluate them in // a Project node. - failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " + + failAnalysis(s"nondeterministic expression ${expr.sql} should not " + s"appear in grouping expression.") } } @@ -217,7 +217,7 @@ trait CheckAnalysis { case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( s"""Only a single table generating function is allowed in a SELECT clause, found: - | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + | ${exprs.map(_.sql).mkString(",")}""".stripMargin) case j: Join if !j.duplicateResolved => val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) @@ -248,9 +248,9 @@ trait CheckAnalysis { failAnalysis( s"""nondeterministic expressions are only allowed in |Project, Filter, Aggregate or Window, found: - | ${o.expressions.map(_.prettyString).mkString(",")} + | ${o.expressions.map(_.sql).mkString(",")} |in operator ${operator.simpleString} - """.stripMargin) + """.stripMargin) case _ => // Analysis successful! } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 5dfce89bd6..b49885d469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -130,7 +130,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP new AttributeReference("gid", IntegerType, false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() + case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -184,7 +184,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. 
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) - val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() + val operator = Alias(e.copy(aggregateFunction = af), e.sql)() // Select the result of the first aggregate in the last aggregate. val result = AggregateExpression( @@ -269,5 +269,5 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.prettyString, e.dataType, true)() + e -> new AttributeReference(e.sql, e.dataType, true)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 79eebbf9b1..01afa01ae9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types.{DataType, StructType} /** @@ -67,6 +68,8 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def toString: String = s"'$name" + + override def sql: String = quoteIdentifier(name) } object UnresolvedAttribute { @@ -141,11 +144,8 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def prettyString: String = { - s"${name}(${children.map(_.prettyString).mkString(",")})" - } - - override def toString: String = s"'$name(${children.mkString(",")})" + override def prettyName: String = name + override def toString: String = s"'$name(${children.mkString(", ")})" } /** @@ -208,10 +208,9 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu Alias(extract, f.name)() } - case _ => { + case _ => throw new AnalysisException("Can only star expand struct data types. Attribute: `" + target.get + "`") - } } } else { val from = input.inputSet.map(_.name).mkString(", ") @@ -228,6 +227,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented * as follows: * MultiAlias(stack_function, Seq(a, b)) + * * @param child the computation being performed * @param names the names to be associated with each output of computing [[child]]. @@ -284,13 +284,14 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override lazy val resolved = false override def toString: String = s"$child[$extraction]" + override def sql: String = s"${child.sql}[${extraction.sql}]" } /** * Holds the expression that has yet to be aliased. * * @param child The computation that is needs to be resolved during analysis. 
- * @param aliasName The name if specified to be asoosicated with the result of computing [[child]] + * @param aliasName The name if specified to be associated with the result of computing [[child]] * */ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 04650d85de..c7be8e886c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -45,7 +45,7 @@ trait ExpectsInputTypes extends Expression { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type." + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c73b2f8f2a..119496c7ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.util.sequenceOption +import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -96,7 +96,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. - val code = s"/* ${this.toCommentSafeString} */" + val code = s"/* ${toCommentSafeString(this.toString)} */" ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") @@ -105,7 +105,7 @@ abstract class Expression extends TreeNode[Expression] { ve.code = genCode(ctx, ve) if (ve.code != "") { // Add `this` in the comment. - ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim) } else { ve } @@ -201,17 +201,6 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyName: String = getClass.getSimpleName.toLowerCase - /** - * Returns a user-facing string representation of this expression, i.e. does not have developer - * centric debugging information like the expression id. 
- */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name, a.dataType) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil @@ -219,24 +208,16 @@ abstract class Expression extends TreeNode[Expression] { override def simpleString: String = toString - override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") + override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")") /** - * Returns the string representation of this expression that is safe to be put in - * code comments of generated code. + * Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]], + * this method may return an arbitrary user facing string. */ - protected def toCommentSafeString: String = this.toString - .replace("*/", "\\*\\/") - .replace("\\u", "\\\\u") - - /** - * Returns SQL representation of this expression. For expressions that don't have a SQL - * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`. - */ - @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation") - def sql: String = throw new UnsupportedOperationException( - s"Cannot map expression $this to its SQL representation" - ) + def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } } @@ -254,6 +235,19 @@ trait Unevaluable extends Expression { } +/** + * Expressions that don't have SQL representation should extend this trait. Examples are + * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`. + */ +trait NonSQLExpression extends Expression { + override def sql: String = { + transform { + case a: Attribute => new PrettyAttribute(a) + }.toString + } +} + + /** * An expression that is nondeterministic. */ @@ -373,8 +367,6 @@ abstract class UnaryExpression extends Expression { """ } } - - override def sql: String = s"($prettyName(${child.sql}))" } @@ -477,8 +469,6 @@ abstract class BinaryExpression extends Expression { """ } } - - override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @@ -499,6 +489,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { def symbol: String + def sqlOperator: String = symbol + override def toString: String = s"($left $symbol $right)" override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) @@ -506,17 +498,17 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. 
if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") } else if (!inputType.acceptsType(left.dataType)) { - TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," + + TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," + s" not ${left.dataType.simpleString}") } else { TypeCheckResult.TypeCheckSuccess } } - override def sql: String = s"(${left.sql} $symbol ${right.sql})" + override def sql: String = s"(${left.sql} $sqlOperator ${right.sql})" } @@ -623,9 +615,4 @@ abstract class TernaryExpression extends Expression { """ } } - - override def sql: String = { - val childrenSQL = children.map(_.sql).mkString(", ") - s"$prettyName($childrenSQL)" - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 681694746b..22184f1ddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -39,11 +39,11 @@ case class ScalaUDF( dataType: DataType, children: Seq[Expression], inputTypes: Seq[DataType] = Nil) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with NonSQLExpression { override def nullable: Boolean = true - override def toString: String = s"UDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(", ")})" // scalastyle:off line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index f88a57a254..ff3064ac66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. 
*/ @@ -92,8 +91,6 @@ private[sql] case class AggregateExpression( AttributeSet(childReferences) } - override def prettyString: String = aggregateFunction.prettyString - override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" override def sql: String = aggregateFunction.sql(isDistinct) @@ -168,7 +165,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu } def sql(isDistinct: Boolean): String = { - val distinct = if (isDistinct) "DISTINCT " else " " + val distinct = if (isDistinct) "DISTINCT " else "" s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1cacd3f76a..5af234609d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -319,7 +319,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } -case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { +case class MaxOf(left: Expression, right: Expression) + extends BinaryArithmetic with NonSQLExpression { + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. override def inputType: AbstractDataType = TypeCollection.Ordered @@ -373,7 +375,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "max" } -case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { +case class MinOf(left: Expression, right: Expression) + extends BinaryArithmetic with NonSQLExpression { + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. override def inputType: AbstractDataType = TypeCollection.Ordered diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a97bd9edce..4c90b3f7d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -123,4 +123,6 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } protected override def nullSafeEval(input: Any): Any = not(input) + + override def sql: String = s"~${child.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index f58a2daf90..1365ee4b55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.util.toCommentSafeString /** * A trait that can be used to provide a fallback mode for expression code generation. 
@@ -37,7 +38,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") if (nullable) { s""" - /* expression: ${this.toCommentSafeString} */ + /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; @@ -48,7 +49,7 @@ trait CodegenFallback extends Expression { } else { ev.isNull = "false" s""" - /* expression: ${this.toCommentSafeString} */ + /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6b24fae9f3..44cdc8d881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -93,6 +93,8 @@ object ExtractValue { } } +trait ExtractValue extends Expression + /** * Returns the value of fields in the Struct `child`. * @@ -102,13 +104,15 @@ object ExtractValue { * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression { + extends UnaryExpression with ExtractValue { private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] override def dataType: DataType = childSchema(ordinal).dataType override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" + override def sql: String = + child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) @@ -130,12 +134,11 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } - - override def sql: String = child.sql + s".`${childSchema(ordinal).name}`" } /** - * Returns the array of value of fields in the Array of Struct `child`. + * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array + * elements, and returns them as a new array. * * No need to do type checking since it is handled by [[ExtractValue]]. 
*/ @@ -144,10 +147,11 @@ case class GetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression { + containsNull: Boolean) extends UnaryExpression with ExtractValue { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" + override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] @@ -204,12 +208,13 @@ case class GetArrayStructFields( * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with ExtractValue { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) override def toString: String = s"$child[$ordinal]" + override def sql: String = s"${child.sql}[${ordinal.sql}]" override def left: Expression = child override def right: Expression = ordinal @@ -250,7 +255,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) * We need to do type checking here as `key` expression maybe unresolved. */ case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with ExtractValue { private def keyType = child.dataType.asInstanceOf[MapType].keyType @@ -258,6 +263,7 @@ case class GetMapValue(child: Expression, key: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" override def left: Expression = child override def right: Expression = key diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1eff2c4dd0..200c6a05df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -35,7 +35,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ca0892eb42..d7d768babc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -238,37 +238,20 @@ case class Literal protected (value: Any, dataType: DataType) } override def sql: String = (value, dataType) match { - case (_, NullType | _: ArrayType | _: MapType | 
_: StructType) if value == null => - "NULL" - - case _ if value == null => - s"CAST(NULL AS ${dataType.sql})" - + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" + case _ if value == null => s"CAST(NULL AS ${dataType.sql})" case (v: UTF8String, StringType) => // Escapes all backslashes and double quotes. "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" - - case (v: Byte, ByteType) => - s"CAST($v AS ${ByteType.simpleString.toUpperCase})" - - case (v: Short, ShortType) => - s"CAST($v AS ${ShortType.simpleString.toUpperCase})" - - case (v: Long, LongType) => - s"CAST($v AS ${LongType.simpleString.toUpperCase})" - - case (v: Float, FloatType) => - s"CAST($v AS ${FloatType.simpleString.toUpperCase})" - - case (v: Decimal, DecimalType.Fixed(precision, scale)) => - s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))" - - case (v: Int, DateType) => - s"DATE '${DateTimeUtils.toJavaDate(v)}'" - - case (v: Long, TimestampType) => - s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" - + case (v: Byte, ByteType) => v + "Y" + case (v: Short, ShortType) => v + "S" + case (v: Long, LongType) => v + "L" + // Float type doesn't have a suffix + case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})" + case (v: Double, DoubleType) => v + "D" + case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})" + case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'" + case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" case _ => value.toString } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 8b9a60f97c..bc2df0fb4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -42,6 +42,7 @@ abstract class LeafMathExpression(c: Double, name: String) override def foldable: Boolean = true override def nullable: Boolean = false override def toString: String = s"$name()" + override def prettyName: String = name override def eval(input: InternalRow): Any = c } @@ -59,6 +60,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" + override def prettyName: String = name protected override def nullSafeEval(input: Any): Any = { f(input.asInstanceOf[Double]) @@ -70,8 +72,6 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } - - override def sql: String = s"$name(${child.sql})" } abstract class UnaryLogExpression(f: Double => Double, name: String) @@ -113,6 +113,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def toString: String = s"$name($left, $right)" + override def prettyName: String = name + override def dataType: DataType = DoubleType protected override def nullSafeEval(input1: Any, input2: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 207b8a0a88..1af5437647 100644 
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -22,6 +22,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types._ object NamedExpression { @@ -183,9 +184,9 @@ case class Alias(child: Expression, name: String)( override def sql: String = { val qualifiersString = - if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".") val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name" - s"${child.sql} AS $qualifiersString`$aliasName`" + s"${child.sql} AS $qualifiersString${quoteIdentifier(aliasName)}" } } @@ -300,9 +301,9 @@ case class AttributeReference( override def sql: String = { val qualifiersString = - if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".") val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name" - s"$qualifiersString`$attrRefName`" + s"$qualifiersString${quoteIdentifier(attrRefName)}" } } @@ -310,10 +311,19 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String, dataType: DataType = NullType) +case class PrettyAttribute( + name: String, + dataType: DataType = NullType) extends Attribute with Unevaluable { + def this(attribute: Attribute) = this(attribute.name, attribute match { + case a: AttributeReference => a.dataType + case a: PrettyAttribute => a.dataType + case _ => NullType + }) + override def toString: String = name + override def sql: String = toString override def withNullability(newNullability: Boolean): Attribute = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 667d3513d3..e22026d584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -83,8 +83,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """ }.mkString("\n") } - - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index fef6825b2d..737346dc79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -46,7 +46,7 @@ case class StaticInvoke( dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression { + propagateNull: Boolean = true) extends Expression with NonSQLExpression { val objectName = staticObject.getName.stripSuffix("$") @@ -108,7 
+108,7 @@ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, - arguments: Seq[Expression] = Nil) extends Expression { + arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { override def nullable: Boolean = true override def children: Seq[Expression] = arguments.+:(targetObject) @@ -204,7 +204,7 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[Literal]) extends Expression { + outerPointer: Option[Literal]) extends Expression with NonSQLExpression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -268,7 +268,7 @@ case class NewInstance( */ case class UnwrapOption( dataType: DataType, - child: Expression) extends UnaryExpression with ExpectsInputTypes { + child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def nullable: Boolean = true @@ -298,7 +298,7 @@ case class UnwrapOption( * @param optType The type of this option. */ case class WrapOption(child: Expression, optType: DataType) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) @@ -328,7 +328,7 @@ case class WrapOption(child: Expression, optType: DataType) * manually, but will instead be passed into the provided lambda function. */ case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression - with Unevaluable { + with Unevaluable with NonSQLExpression { override def nullable: Boolean = true @@ -368,7 +368,7 @@ object MapObjects { case class MapObjects private( loopVar: LambdaVariable, lambdaFunction: Expression, - inputData: Expression) extends Expression { + inputData: Expression) extends Expression with NonSQLExpression { private def itemAccessorMethod(dataType: DataType): String => String = dataType match { case NullType => @@ -483,7 +483,7 @@ case class MapObjects private( * * @param children A list of expression to use as content of the external row. */ -case class CreateExternalRow(children: Seq[Expression]) extends Expression { +case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression { override def dataType: DataType = ObjectType(classOf[Row]) override def nullable: Boolean = false @@ -516,7 +516,8 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { * Serializes an input object using a generic serializer (Kryo or Java). * @param kryo if true, use Kryo. Otherwise, use Java. */ -case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) + extends UnaryExpression with NonSQLExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") @@ -558,7 +559,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) - extends UnaryExpression { + extends UnaryExpression with NonSQLExpression { override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Code to initialize the serializer. 
@@ -596,7 +597,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B * Initialize a Java Bean instance by setting its field values via setters. */ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) - extends Expression { + extends Expression with NonSQLExpression { override def nullable: Boolean = beanInstance.nullable override def children: Seq[Expression] = beanInstance +: setters.values.toSeq @@ -638,7 +639,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * non-null `s`, `s.i` can't be null. */ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) - extends UnaryExpression { + extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c290aa8825..20818bfb1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -249,6 +249,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with override def symbol: String = "&&" + override def sqlOperator: String = "AND" + override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == false) { @@ -289,8 +291,6 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } """ } - - override def sql: String = s"(${left.sql} AND ${right.sql})" } @@ -300,6 +300,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P override def symbol: String = "||" + override def sqlOperator: String = "OR" + override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == true) { @@ -340,8 +342,6 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } """ } - - override def sql: String = s"(${left.sql} OR ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b965212f27..4be065b30a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,7 +23,6 @@ import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -62,8 +61,6 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } - - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -156,8 +153,6 @@ case class ConcatWs(children: Seq[Expression]) """ } } - - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index afe122f6a0..9e6bd0ee46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -556,8 +556,8 @@ abstract class RankLike extends AggregateWindowFunction { override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) /** Store the values of the window 'order' expressions. */ - protected val orderAttrs = children.map{ expr => - AttributeReference(expr.prettyString, expr.dataType)() + protected val orderAttrs = children.map { expr => + AttributeReference(expr.sql, expr.dataType)() } /** Predicate that detects if the order attributes have changed. */ @@ -636,7 +636,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { /** * The PercentRank function computes the percentage ranking of a value in a group of values. The - * result the rank of the minus one divided by the total number of rows in the partitiion minus one: + * result the rank of the minus one divided by the total number of rows in the partition minus one: * (r - 1) / (n - 1). If a partition only contains one row, the function will return 0. * * The PercentRank function is similar to the CumeDist function, but it uses rank values instead of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7d155ac183..a19ba38ba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -575,7 +575,7 @@ case class Pivot( override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)()) + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 7a0d0de632..43f707f444 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst import java.io._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{NumericType, StringType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils package object util { @@ -130,20 +133,32 @@ package object util { ret } - /** - * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`. - */ - def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { - case xs if xs.isEmpty => - Option(Seq.empty[T]) - - case xs => - for { - head <- xs.head - tail <- sequenceOption(xs.tail) - } yield head +: tail + // Replaces attributes, string literals, complex type extractors with their pretty form so that + // generated column names don't contain back-ticks or double-quotes. 
+ def usePrettyExpression(e: Expression): Expression = e transform { + case a: Attribute => new PrettyAttribute(a) + case Literal(s: UTF8String, StringType) => PrettyAttribute(s.toString, StringType) + case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t) + case e: GetStructField => + val name = e.name.getOrElse(e.childSchema(e.ordinal).name) + PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType) + case e: GetArrayStructFields => + PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType) } + def quoteIdentifier(name: String): String = { + // Escapes back-ticks within the identifier name with double-back-ticks, and then quote the + // identifier with back-ticks. + "`" + name.replace("`", "``") + "`" + } + + /** + * Returns the string representation of this expression that is safe to be put in + * code comments of generated code. + */ + def toCommentSafeString(str: String): String = + str.replace("*/", "\\*\\/").replace("\\u", "\\\\u") + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 5dd661ee6b..71ea5b8841 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -64,6 +64,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" + override def sql: String = typeName.toUpperCase + /** * Returns whether this DecimalType is wider than `other`. If yes, it means `other` * can be casted into `this` safely without losing any precision or range. 
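Note: the two helpers added to the catalyst util package above carry the new quoting and comment-escaping behaviour. As a quick standalone sketch of what they do — the wrapping object, main method, and sample inputs below are illustrative only, not part of this patch, though the two method bodies are copied from it:

    object EscapingSketch {
      // Escapes back-ticks inside an identifier by doubling them, then wraps the whole name in back-ticks.
      def quoteIdentifier(name: String): String =
        "`" + name.replace("`", "``") + "`"

      // Neutralises "*/" and "\u" so an arbitrary expression string is safe inside a generated-code comment.
      def toCommentSafeString(str: String): String =
        str.replace("*/", "\\*\\/").replace("\\u", "\\\\u")

      def main(args: Array[String]): Unit = {
        println(quoteIdentifier("a`b"))        // prints: `a``b`
        println(quoteIdentifier("order"))      // prints: `order`
        println(toCommentSafeString("x */ y")) // prints: x \*\/ y
      }
    }

This is why Alias and AttributeReference above now quote through quoteIdentifier instead of hard-coding back-ticks, and why Expression and CodegenFallback pass this.toString through toCommentSafeString before emitting it into generated-code comments.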
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e797d83cb0..5ff5435d5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} -import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, DataTypeParser, LegacyTypeStringParser} /** * :: DeveloperApi :: @@ -280,7 +280,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } override def sql: String = { - val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}") + val fieldTypes = fields.map(f => s"${quoteIdentifier(f.name)}: ${f.dataType.sql}") s"STRUCT<${fieldTypes.mkString(", ")}>" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 69f78a097e..de9a56dc9c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -127,22 +127,24 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" :: - "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE))" :: "argument 1" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" :: - "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" :: + "argument 2" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "requires int type" :: "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" :: + "argument 1" :: "argument 2" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "invalid window function", @@ -207,7 +209,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "sorting by attributes are not from grouping expressions", testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc), - "cannot resolve" :: "'b'" :: "given input columns" :: "[a, c, a3]" :: Nil) + "cannot resolve" :: "'`b`'" :: "given input columns" :: "[a, c, a3]" :: Nil) errorTest( "non-boolean filters", @@ -222,7 +224,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "missing group by", testRelation2.groupBy('a)('b), - "'b'" :: "group by" :: Nil + "'`b`'" :: "group by" :: Nil ) errorTest( @@ -270,7 
+272,7 @@ class AnalysisErrorSuite extends AnalysisTest { "SPARK-9955: correct error message for aggregate", // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))), - "cannot resolve 'bad_column'" :: Nil) + "cannot resolve '`bad_column`'" :: Nil) test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) @@ -311,7 +313,7 @@ class AnalysisErrorSuite extends AnalysisTest { case true => assertAnalysisSuccess(plan, true) case false => - assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil) + assertAnalysisError(plan, "expression `a` cannot be used as a grouping expression" :: Nil) } } @@ -372,7 +374,7 @@ class AnalysisErrorSuite extends AnalysisTest { Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil) + assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil) val plan2 = Join( @@ -386,6 +388,6 @@ class AnalysisErrorSuite extends AnalysisTest { Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil) + assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index af214b7af0..2756c463cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -71,8 +71,17 @@ trait AnalysisTest extends PlanTest { val e = intercept[AnalysisException] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains), - s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " + - s"actually we get ${e.getMessage}") + + if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + fail( + s"""Exception message should contain the following substrings: + | + | ${expectedErrors.mkString("\n ")} + | + |Actual exception message: + | + | ${e.getMessage} + """.stripMargin) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 59549e3998..92c8496fde 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -41,7 +41,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(expr) } assert(e.getMessage.contains( - s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + s"cannot resolve '${expr.sql}' due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } @@ -52,7 +52,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def 
assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in '${expr.prettyString}'") + s"differing types in '${expr.sql}'") } test("check types for unary arithmetic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 8b02b63c6c..3ad0dae767 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -142,7 +142,7 @@ class EncoderResolutionSuite extends PlanTest { }.message assert(msg2 == s""" - |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate + |Cannot up cast `b`.`b` from decimal(38,18) to bigint as it may truncate |The type path of the target object is: |- field (class: "scala.Long", name: "b") |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 2ab091e40a..6c7929c362 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DataTypeParser} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ @@ -104,10 +104,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => { + case _ if name.endsWith(".*") => val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) UnresolvedStar(Some(parts)) - } case _ => UnresolvedAttribute.quotedString(name) }) @@ -123,6 +122,8 @@ class Column(protected[sql] val expr: Expression) extends Logging { // make it a NamedExpression. case u: UnresolvedAttribute => UnresolvedAlias(u) + case u: UnresolvedExtractValue => UnresolvedAlias(u) + case expr: NamedExpression => expr // Leave an unaliased generator with an empty list of names since the analyzer will generate @@ -131,7 +132,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { case jt: JsonTuple => MultiAlias(jt, Nil) - case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString)) + case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql)) // If we have a top level Cast, there is a chance to give it a better alias, if there is a // NamedExpression under this Cast. 
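A note on the Column.scala hunks around this point: when a selected expression carries no explicit name, the fallback alias is now derived from the prettified SQL text of the expression (usePrettyExpression(expr).sql) rather than the old prettyString. A minimal sketch of that fallback is below; the usePrettyExpression import is the one this patch adds to Column.scala, and the helper name autoAlias is invented here purely for illustration, it is not part of the patch.

    import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
    import org.apache.spark.sql.catalyst.util.usePrettyExpression

    // Invented helper mirroring the fallback branch of Column.named after this change:
    // anything that is not already a NamedExpression is aliased with the prettified SQL
    // rendering of the expression, so auto-generated column names read like the query text.
    def autoAlias(expr: Expression): NamedExpression = expr match {
      case ne: NamedExpression => ne
      case other => Alias(other, usePrettyExpression(other).sql)()
    }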
@@ -139,13 +140,13 @@ class Column(protected[sql] val expr: Expression) extends Logging { case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to)) } match { case ne: NamedExpression => ne - case other => Alias(expr, expr.prettyString)() + case other => Alias(expr, usePrettyExpression(expr).sql)() } - case expr: Expression => Alias(expr, expr.prettyString)() + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } - override def toString: String = expr.prettyString + override def toString: String = usePrettyExpression(expr).sql override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) @@ -987,7 +988,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { if (extended) { println(expr) } else { - println(expr.prettyString) + println(expr.sql) } // scalastyle:on println } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 76c09a285d..9674450118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator @@ -1359,7 +1360,8 @@ class DataFrame private[sql]( "min" -> ((child: Expression) => Min(child).toAggregateExpression()), "max" -> ((child: Expression) => Max(child).toAggregateExpression())) - val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + val outputCols = + (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index c74ef2c035..66ec0e7338 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} +import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.types.NumericType /** @@ -74,7 +75,7 @@ class GroupedData protected[sql]( private[this] def alias(expr: Expression): NamedExpression = expr match { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ad8564e96a..990eeb22b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} import org.apache.spark.sql.execution.metric.LongSQLMetricValue @@ -252,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } /** Codegened pipeline for: - * ${plan.treeString.trim} + * ${toCommentSafeString(plan.treeString.trim)} */ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 812e696338..ab178443dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -324,7 +324,7 @@ private[sql] case class ScalaUDAF( udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with Logging { + extends ImperativeAggregate with NonSQLExpression with Logging { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 1c0d53fc77..54dda0c391 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -544,7 +544,7 @@ private[parquet] object CatalystSchemaConverter { !name.matches(".*[ ,;{}()\n\t=].*"), s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". |Please use alias to rename it. 
- """.stripMargin.split("\n").mkString(" ")) + """.stripMargin.split("\n").mkString(" ").trim) } def checkFieldNames(schema: StructType): StructType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 0e53a0c473..9aff0be716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.{Accumulator, Logging} import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} import org.apache.spark.sql.types.DataType /** @@ -36,9 +36,12 @@ case class PythonUDF( broadcastVars: java.util.List[Broadcast[PythonBroadcast]], accumulator: Accumulator[java.util.List[Array[Byte]]], dataType: DataType, - children: Seq[Expression]) extends Expression with Unevaluable with Logging { + children: Seq[Expression]) + extends Expression with Unevaluable with NonSQLExpression with Logging { - override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" + override def toString: String = s"PythonUDF#$name(${children.mkString(", ")})" override def nullable: Boolean = true + + override def sql: String = s"$name(${children.mkString(", ")})" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index f9ba607700..498f007081 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -525,7 +525,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { ds.as[ClassData2] } - assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) + assert(e.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"), e.getMessage) } test("runtime nullability check") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index bd87449f92..41a9404b00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -455,7 +455,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) @@ -479,7 +479,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1) as 
div0").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index bf5edb4759..1dda39d44e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -24,10 +24,11 @@ import scala.util.control.NonFatal import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.execution.datasources.LogicalRelation /** @@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation * supported by this builder (yet). */ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + require(logicalPlan.resolved, "SQLBuilder only supports resloved logical query plans") + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) def toSQL: String = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) try { + canonicalizedPlan.transformAllExpressions { + case e: NonSQLExpression => + throw new UnsupportedOperationException( + s"Expression $e doesn't have a SQL representation" + ) + case e => e + } + val generatedSQL = toSQL(canonicalizedPlan) logDebug( s"""Built SQL query string successfully from given logical plan: @@ -95,7 +106,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi p.child match { // Persisted data source relation case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => - s"`$database`.`$table`" + s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" // Parentheses is not used for persisted data source relations // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => @@ -114,8 +125,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: MetastoreRelation => build( - s"`${p.databaseName}`.`${p.tableName}`", - p.alias.map(a => s" AS `$a`").getOrElse("") + s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}", + p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("") ) case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) @@ -148,7 +159,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi * The segments are trimmed so only a single space appears in the separation. * For example, `build("a", " b ", " c")` becomes "a b c". 
*/ - private def build(segments: String*): String = segments.map(_.trim).mkString(" ") + private def build(segments: String*): String = + segments.map(_.trim).filter(_.nonEmpty).mkString(" ") private def projectToSQL(plan: Project, isDistinct: Boolean): String = { build( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index d5ed838ca4..bcafa045e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types._ @@ -70,18 +69,28 @@ private[hive] class HiveFunctionRegistry( // catch the exception and throw AnalysisException instead. try { if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF( + val udf = HiveGenericUDF( name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + udf.dataType // Force it to check input data types. + udf } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) + val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) + udf.dataType // Force it to check input data types. + udf } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) + val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) + udf.dataType // Force it to check input data types. + udf } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) + val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) + udaf.dataType // Force it to check input data types. + udaf } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction( + val udaf = HiveUDAFFunction( name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + udaf.dataType // Force it to check input data types. + udaf } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) udtf.elementTypes // Force it to check input data types. 
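The quoteIdentifier helper that the StructType.sql and SQLBuilder hunks above switch to deserves a quick gloss, since the backticked identifiers in the updated error-message expectations (e.g. cannot resolve '`b`') and the attribute tests later in this patch all depend on it. The sketch below captures the behaviour implied by those tests (the real definition lives in org.apache.spark.sql.catalyst.util; treat this as an approximation rather than the authoritative source).

    // Wrap an identifier in backticks, escaping embedded backticks by doubling them,
    // so generated SQL and analysis error messages round-trip unusual column names.
    def quoteIdentifier(name: String): String = "`" + name.replace("`", "``") + "`"

    quoteIdentifier("a")       // "`a`"
    quoteIdentifier("foo bar") // "`foo bar`"
    quoteIdentifier("we`ird")  // "`we``ird`"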
@@ -163,7 +172,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val conversionHelper = new ConversionHelper(method, arguments) - override val dataType = javaClassToDataType(method.getReturnType) + override lazy val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( @@ -189,6 +198,8 @@ private[hive] case class HiveSimpleUDF( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + override def prettyName: String = name + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } @@ -233,11 +244,11 @@ private[hive] case class HiveGenericUDF( } @transient - private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType) }.toArray[DeferredObject] - override val dataType: DataType = inspectorToDataType(returnInspector) + override lazy val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. @@ -245,20 +256,20 @@ private[hive] case class HiveGenericUDF( var i = 0 while (i < children.length) { val idx = i - deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set( + deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set( () => { children(idx).eval(input) }) i += 1 } - unwrap(function.evaluate(deferedObjects), returnInspector) + unwrap(function.evaluate(deferredObjects), returnInspector) } + override def prettyName: String = name + override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - - override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -340,7 +351,7 @@ private[hive] case class HiveGenericUDTF( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" + override def prettyName: String = name } /** @@ -432,7 +443,9 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false - override val dataType: DataType = inspectorToDataType(returnInspector) + override lazy val dataType: DataType = inspectorToDataType(returnInspector) + + override def prettyName: String = name override def sql(isDistinct: Boolean): String = { val distinct = if (isDistinct) "DISTINCT " else " " diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala index 3fb6543b1a..e4b4d1861a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -26,17 +26,24 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { test("literal") { checkSQL(Literal("foo"), "\"foo\"") checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") - checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)") - checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)") + checkSQL(Literal(1: Byte), "1Y") + checkSQL(Literal(2: Short), "2S") checkSQL(Literal(4: Int), "4") - checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") + checkSQL(Literal(8: Long), "8L") checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") - checkSQL(Literal(2.5D), "2.5") + checkSQL(Literal(2.5D), "2.5D") checkSQL( 
Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')") // TODO tests for decimals } + test("attributes") { + checkSQL('a.int, "`a`") + checkSQL(Symbol("foo bar").int, "`foo bar`") + // Keyword + checkSQL('int.int, "`int`") + } + test("binary comparisons") { checkSQL('a.int === 'b.int, "(`a` = `b`)") checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index dc8ac7e47f..5255b150aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -29,6 +29,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { sql("DROP TABLE IF EXISTS t0") sql("DROP TABLE IF EXISTS t1") sql("DROP TABLE IF EXISTS t2") + sqlContext.range(10).write.saveAsTable("t0") sqlContext @@ -168,4 +169,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } } } + + test("plans with non-SQL expressions") { + sqlContext.udf.register("foo", (_: Int) * 2) + intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL) + } + + test("named expression in column names shouldn't be quoted") { + def checkColumnNames(query: String, expectedColNames: String*): Unit = { + checkHiveQl(query) + assert(sql(query).columns === expectedColNames) + } + + // Attributes + checkColumnNames( + """SELECT * FROM ( + | SELECT 1 AS a, 2 AS b, 3 AS `we``ird` + |) s + """.stripMargin, + "a", "b", "we`ird" + ) + + checkColumnNames( + """SELECT x.a, y.a, x.b, y.b + |FROM (SELECT 1 AS a, 2 AS b) x + |INNER JOIN (SELECT 1 AS a, 2 AS b) y + |ON x.a = y.a + """.stripMargin, + "a", "a", "b", "b" + ) + + // String literal + checkColumnNames( + "SELECT 'foo', '\"bar\\''", + "foo", "\"bar\'" + ) + + // Numeric literals (should have CAST or suffixes in column names) + checkColumnNames( + "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D", + "1", "2", "3", "4", "5.1", "6.1" + ) + + // Aliases + checkColumnNames( + "SELECT 1 AS a", + "a" + ) + + // Complex type extractors + checkColumnNames( + """SELECT + | a.f1, b[0].f1, b.f1, c["foo"], d[0] + |FROM ( + | SELECT + | NAMED_STRUCT("f1", 1, "f2", "foo") AS a, + | ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b, + | MAP("foo", 1) AS c, + | ARRAY(1) AS d + |) s + """.stripMargin, + "f1", "b[0].f1", "f1", "c[foo]", "d[0]" + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index cd055f9eca..d8d3448add 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton { test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") - val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan + val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 27ea3e8041..fe446774ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -131,17 +131,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA val df = sql( """ |SELECT - | CAST(null as TINYINT), - | CAST(null as SMALLINT), - | CAST(null as INT), - | CAST(null as BIGINT), - | CAST(null as FLOAT), - | CAST(null as DOUBLE), - | CAST(null as DECIMAL(7,2)), - | CAST(null as TIMESTAMP), - | CAST(null as DATE), - | CAST(null as STRING), - | CAST(null as VARCHAR(10)) + | CAST(null as TINYINT) as c0, + | CAST(null as SMALLINT) as c1, + | CAST(null as INT) as c2, + | CAST(null as BIGINT) as c3, + | CAST(null as FLOAT) as c4, + | CAST(null as DOUBLE) as c5, + | CAST(null as DECIMAL(7,2)) as c6, + | CAST(null as TIMESTAMP) as c7, + | CAST(null as DATE) as c8, + | CAST(null as STRING) as c9, + | CAST(null as VARCHAR(10)) as c10 |FROM orc_temp_table limit 1 """.stripMargin) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c997453803..a6ca7d0386 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -626,7 +626,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest { sql( s"""CREATE TABLE array_of_struct |STORED AS PARQUET LOCATION '$path' - |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) + |AS SELECT + | '1st' AS a, + | '2nd' AS b, + | ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c """.stripMargin) checkAnswer(
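Most of the AS aliases added to the test SQL above follow from the same consequence of this patch: an unaliased expression's column now takes its SQL text as its name. For ParquetIOSuite in particular, a name like div0(1) contains characters from " ,;{}()\n\t=" that the CatalystSchemaConverter check quoted earlier rejects when writing Parquet. A hedged repro sketch follows, reusing the div0 UDF from that suite; dir stands in for a withTempPath directory, and the exact generated column name is an assumption, not verified output.

    sqlContext.udf.register("div0", (x: Int) => x / 0)

    // Unaliased, the result column would presumably be named "div0(1)", which would fail the
    // Parquet field-name check before the divide-by-zero ever executes:
    // sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)

    // With the alias the field name is legal, so the write fails for the intended reason
    // (a SparkException from the division), which is what the test asserts.
    sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)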