From e352de0db2789919e1e0385b79f29b508a6b2b77 Mon Sep 17 00:00:00 2001 From: Nong Date: Tue, 3 Nov 2015 16:44:37 -0800 Subject: [PATCH] [SPARK-11329] [SQL] Cleanup from spark-11329 fix. Author: Nong Closes #9442 from nongli/spark-11483. --- .../apache/spark/sql/catalyst/SqlParser.scala | 4 +- .../sql/catalyst/analysis/unresolved.scala | 18 +---- .../scala/org/apache/spark/sql/Column.scala | 6 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 77 +++++++++++-------- 4 files changed, 54 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 1ba559d9e3..440e9e28fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -477,8 +477,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | (ident <~ "."). + <~ "*" ^^ { case target => { UnresolvedStar(Option(target)) } - } | primary + | (ident <~ "."). + <~ "*" ^^ { case target => UnresolvedStar(Option(target))} + | primary ) protected lazy val signedPrimary: Parser[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6975662e2b..eae17c86dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -183,28 +183,16 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu case None => input.output // If there is a table, pick out attributes that are part of this table. case Some(t) => if (t.size == 1) { - input.output.filter(_.qualifiers.filter(resolver(_, t.head)).nonEmpty) + input.output.filter(_.qualifiers.exists(resolver(_, t.head))) } else { List() } } - if (!expandedAttributes.isEmpty) { - if (expandedAttributes.forall(_.isInstanceOf[NamedExpression])) { - return expandedAttributes - } else { - require(expandedAttributes.size == input.output.size) - expandedAttributes.zip(input.output).map { - case (e, originalAttribute) => - Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) - } - } - return expandedAttributes - } - - require(target.isDefined) + if (expandedAttributes.nonEmpty) return expandedAttributes // Try to resolve it as a struct expansion. If there is a conflict and both are possible, // (i.e. [name].* is both a table and a struct), the struct path can always be qualified. + require(target.isDefined) val attribute = input.resolve(target.get, resolver) if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 3cde9d6cb4..c73f696962 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -60,8 +60,10 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName( - name.substring(0, name.length - 2)))) + case _ if name.endsWith(".*") => { + val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) + UnresolvedStar(Some(parts)) + } case _ => UnresolvedAttribute.quotedString(name) }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ee54bff24b..6388a8b9c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.execution.joins.{SortMergeJoin, CartesianProduct} +import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} @@ -1956,7 +1956,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // Try with a registered table. sql("select struct(a, b) as record from testData2").registerTempTable("structTable") - checkAnswer(sql("SELECT record.* FROM structTable"), + checkAnswer( + sql("SELECT record.* FROM structTable"), Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) checkAnswer(sql( @@ -2019,50 +2020,62 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) // Try with a registered table - nestedStructData.registerTempTable("nestedStructTable") - checkAnswer(sql("SELECT record.* FROM nestedStructTable"), - nestedStructData.select($"record.*")) - checkAnswer(sql("SELECT record.r1 FROM nestedStructTable"), - nestedStructData.select($"record.r1")) - checkAnswer(sql("SELECT record.r1.* FROM nestedStructTable"), - nestedStructData.select($"record.r1.*")) + withTempTable("nestedStructTable") { + nestedStructData.registerTempTable("nestedStructTable") + checkAnswer( + sql("SELECT record.* FROM nestedStructTable"), + nestedStructData.select($"record.*")) + checkAnswer( + sql("SELECT record.r1 FROM nestedStructTable"), + nestedStructData.select($"record.r1")) + checkAnswer( + sql("SELECT record.r1.* FROM nestedStructTable"), + nestedStructData.select($"record.r1.*")) - // Create paths with unusual characters. + // Try resolving something not there. + assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable")) + .getMessage.contains("cannot resolve")) + } + + // Create paths with unusual characters val specialCharacterPath = sql( """ | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp """.stripMargin) - specialCharacterPath.registerTempTable("specialCharacterTable") - checkAnswer(specialCharacterPath.select($"`r&&b.c`.*"), - nestedStructData.select($"record.*")) - checkAnswer(sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), - nestedStructData.select($"record.r1")) - checkAnswer(sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), - nestedStructData.select($"record.r2")) - checkAnswer(sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), - nestedStructData.select($"record.r1.*")) + withTempTable("specialCharacterTable") { + specialCharacterPath.registerTempTable("specialCharacterTable") + checkAnswer( + specialCharacterPath.select($"`r&&b.c`.*"), + nestedStructData.select($"record.*")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + nestedStructData.select($"record.r1")) + checkAnswer( + sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), + nestedStructData.select($"record.r2")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), + nestedStructData.select($"record.r1.*")) + } // Try star expanding a scalar. This should fail. assert(intercept[AnalysisException](sql("select a.* from testData2")).getMessage.contains( "Can only star expand struct data types.")) - - // Try resolving something not there. - assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable")) - .getMessage.contains("cannot resolve")) } - test("Struct Star Expansion - Name conflict") { // Create a data set that contains a naming conflict val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2") - nameConflict.registerTempTable("nameConflict") - // Unqualified should resolve to table. - checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), - Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: - Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil) - // Qualify the struct type with the table name. - checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"), - Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + withTempTable("nameConflict") { + nameConflict.registerTempTable("nameConflict") + // Unqualified should resolve to table. + checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), + Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: + Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil) + // Qualify the struct type with the table name. + checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + } } }